Skip to content

Commit 8638c2a

Browse files
committed
comments
1 parent 7e71e03 commit 8638c2a

File tree

6 files changed

+760
-389
lines changed

6 files changed

+760
-389
lines changed

eval_protocol/pytest/integrations/openenv_trl_vllm.py

Lines changed: 50 additions & 55 deletions
Original file line number | Diff line number | Diff line change
@@ -37,6 +37,7 @@ def create_openenv_vllm_rollout_func(
3737
# Environment configuration
3838
env_client_cls: Optional[Type[Any]] = None,
3939
tasks: List[str] | None = None,
40+
task_var: Optional[str] = None,
4041
miniwob_url: str | None = None,
4142
docker_image: str = "browsergym-env:latest",
4243
env_base_url: Optional[str] = None,
@@ -66,52 +67,48 @@ def create_openenv_vllm_rollout_func(
6667
The environment side is configured via ``env_client_cls`` and the BrowserGym
6768
parameters (``tasks``, ``miniwob_url``, ``docker_image``, etc.).
6869
"""
69-
print(f"\n{'='*80}", flush=True)
70-
print(f"[openenv_trl_vllm] create_openenv_vllm_rollout_func() CALLED", flush=True)
70+
print(f"\n{'=' * 80}", flush=True)
71+
print("[openenv_trl_vllm] create_openenv_vllm_rollout_func() CALLED", flush=True)
7172
print(f" vllm_base_url: {vllm_base_url}", flush=True)
7273
print(f" vllm_model: {vllm_model}", flush=True)
7374
print(f" tasks: {tasks}", flush=True)
7475
print(f" max_steps: {max_steps}", flush=True)
75-
print(f"{'='*80}", flush=True)
76+
print(f"{'=' * 80}", flush=True)
7677
sys.stdout.flush()
77-
78+
7879
# Import VLLMPolicy
7980
from eval_protocol.mcp.execution.vllm_policy import VLLMPolicy
8081

8182
# Global-ish task rotation offset across rollout_func calls.
8283
# This lets us rotate tasks between GRPO steps instead of always
8384
# starting from tasks[0] when a new OpenEnvRolloutProcessor is created.
8485
task_cycle_index: int = 0
85-
86+
8687
def rollout_func(prompts: List[str], trainer) -> Dict[str, List]:
8788
"""Execute rollouts via OpenEnv + vLLM and return GRPO-compatible results."""
8889
print("\n[OpenEnvVLLM] rollout_func called", flush=True)
89-
90+
9091
# Extract args from trainer
9192
args = trainer.args
9293
processing_class = trainer.processing_class
93-
94+
9495
num_generations = getattr(args, "num_generations", 8)
9596
print(
96-
f"[OpenEnvVLLM] Received {len(prompts)} prompts, "
97-
f"{num_generations} generations each",
97+
f"[OpenEnvVLLM] Received {len(prompts)} prompts, {num_generations} generations each",
9898
flush=True,
9999
)
100-
100+
101101
# 1) Build evaluation rows
102102
evaluation_rows: List[EvaluationRow] = []
103103
for prompt in prompts:
104104
for gen_idx in range(num_generations):
105-
evaluation_rows.append(
106-
EvaluationRow(
107-
messages=[Message(role="user", content=prompt)],
108-
input_metadata=InputMetadata(
109-
completion_params={},
110-
extra={"generation_idx": gen_idx}
111-
),
112-
)
105+
row = EvaluationRow(
106+
messages=[Message(role="user", content=prompt)],
107+
input_metadata=InputMetadata(completion_params={}),
113108
)
114-
109+
row.input_metadata.generation_idx = gen_idx # type: ignore[attr-defined]
110+
evaluation_rows.append(row)
111+
115112
# 2) Build processor config with VLLMPolicy
116113
# We'll pass trainer.vllm_client to VLLMPolicy
117114
base_params: Dict[str, Any] = {
@@ -121,37 +118,33 @@ def rollout_func(prompts: List[str], trainer) -> Dict[str, List]:
121118
}
122119
if completion_params:
123120
base_params.update(completion_params)
124-
121+
125122
print(
126-
f"[OpenEnvVLLM] Temperature={base_params['temperature']}, "
127-
f"max_tokens={base_params['max_tokens']}",
123+
f"[OpenEnvVLLM] Temperature={base_params['temperature']}, max_tokens={base_params['max_tokens']}",
128124
flush=True,
129125
)
130126
print("[OpenEnvVLLM] Using TRL VLLMClient from trainer", flush=True)
131-
132-
max_concurrency = concurrency if concurrency is not None else getattr(
133-
args, "per_device_train_batch_size", 1
134-
)
127+
128+
max_concurrency = concurrency if concurrency is not None else getattr(args, "per_device_train_batch_size", 1)
135129
print(
136-
f"[OpenEnvVLLM] Max concurrency={max_concurrency}, "
137-
f"max_steps={max_steps}",
130+
f"[OpenEnvVLLM] Max concurrency={max_concurrency}, max_steps={max_steps}",
138131
flush=True,
139132
)
140-
133+
141134
config = RolloutProcessorConfig(
142135
completion_params=base_params,
143136
mcp_config_path="",
144137
semaphore=asyncio.Semaphore(max_concurrency),
145138
steps=max_steps,
146139
)
147-
140+
148141
# 3) Execute rollouts with VLLMPolicy
149142
print(
150143
f"[OpenEnvVLLM] Instantiating processor: "
151144
f"{processor_cls.__name__ if processor_cls else 'OpenEnvRolloutProcessor'}",
152145
flush=True,
153146
)
154-
147+
155148
# Create policy factory that uses trainer's vllm_client
156149
def vllm_policy_factory(model, temperature, max_tokens, base_url=None, **kwargs):
157150
"""Factory that creates VLLMPolicy using trainer's vllm_client."""
@@ -164,7 +157,7 @@ def vllm_policy_factory(model, temperature, max_tokens, base_url=None, **kwargs)
164157
top_k=kwargs.get("top_k"),
165158
**kwargs,
166159
)
167-
160+
168161
Processor = processor_cls or OpenEnvRolloutProcessor
169162
_kwargs: Dict[str, Any] = dict(processor_kwargs or {})
170163
_kwargs.setdefault("env_factory", env_factory)
@@ -187,6 +180,7 @@ def vllm_policy_factory(model, temperature, max_tokens, base_url=None, **kwargs)
187180
flush=True,
188181
)
189182
_kwargs.setdefault("tasks", rotated_tasks)
183+
_kwargs.setdefault("task_var", task_var)
190184

191185
_kwargs.setdefault("miniwob_url", miniwob_url)
192186
_kwargs.setdefault("docker_image", docker_image)
@@ -202,48 +196,49 @@ def vllm_policy_factory(model, temperature, max_tokens, base_url=None, **kwargs)
202196
_kwargs.setdefault("viewport_height", viewport_height)
203197
_kwargs.setdefault("timeout_ms", timeout_ms)
204198
_kwargs.setdefault("num_generations", num_generations)
205-
199+
206200
processor = Processor(**_kwargs)
207-
print(f"[OpenEnvVLLM] Processor instantiated successfully", flush=True)
208-
201+
print("[OpenEnvVLLM] Processor instantiated successfully", flush=True)
202+
209203
loop = asyncio.new_event_loop()
210204
asyncio.set_event_loop(loop)
211205
try:
206+
212207
async def _run_all():
213208
tasks_list = processor(evaluation_rows, config)
214209
return await asyncio.gather(*tasks_list)
215-
210+
216211
completed_rows = loop.run_until_complete(_run_all())
217212
print(
218213
f"[OpenEnvVLLM] All rollouts completed: {len(completed_rows)} results",
219214
flush=True,
220215
)
221216
finally:
222217
loop.close()
223-
218+
224219
# 4) Convert to Wordle-style format (no splitting)
225220
# Each completed_row is one rollout with multiple turns
226221
# We .extend() tokens across turns, then .append() per rollout
227222
print(
228223
f"[OpenEnvVLLM] Converting {len(completed_rows)} rollouts to TRL format",
229224
flush=True,
230225
)
231-
226+
232227
tokenizer = getattr(processing_class, "tokenizer", None) or processing_class
233228
encode_fn = getattr(tokenizer, "encode", None)
234-
229+
235230
episode_prompt_ids: List[List[int]] = []
236231
episode_completion_ids: List[List[int]] = []
237232
episode_logprobs: List[List[float]] = []
238233
step_rewards_all: List[List[float]] = []
239-
234+
240235
for idx, row in enumerate(completed_rows):
241236
# Accumulate tokens across all turns in this rollout
242237
prompt_ids: List[int] = [] # .extend() for each turn
243238
completion_ids: List[int] = [] # .extend() for each turn
244239
logprobs: List[float] = [] # .extend() for each turn
245240
rewards: List[float] = []
246-
241+
247242
# Go through all messages and accumulate tokens
248243
for msg in row.messages:
249244
if msg.role == "user":
@@ -259,50 +254,50 @@ async def _run_all():
259254
content = msg.content or ""
260255
if isinstance(content, str) and content.startswith("__ep_step_rewards__:"):
261256
import json
257+
262258
payload = content.split(":", 1)[1]
263259
rewards = json.loads(payload) or []
264260
except Exception:
265261
pass
266-
267-
# Fallback for rewards
268-
if not rewards and hasattr(row.execution_metadata, "extra"):
262+
263+
# Fallback for rewards (if extra field exists via model_config extra="allow")
264+
if not rewards:
269265
try:
270-
rewards = row.execution_metadata.extra.get("step_rewards", []) or []
266+
extra = getattr(row.execution_metadata, "extra", None)
267+
if isinstance(extra, dict):
268+
rewards = extra.get("step_rewards", []) or []
271269
except Exception:
272270
pass
273-
271+
274272
# Append accumulated tokens for this episode
275273
episode_prompt_ids.append(prompt_ids if prompt_ids else [0])
276274
episode_completion_ids.append(completion_ids if completion_ids else [0])
277275
episode_logprobs.append(logprobs if logprobs else [0.0])
278276
step_rewards_all.append(rewards if rewards else [0.0])
279-
277+
280278
total_reward = sum(sum(r) for r in step_rewards_all)
281279
avg_reward = total_reward / len(step_rewards_all) if step_rewards_all else 0.0
282280
print(
283281
f"[OpenEnvVLLM] Total reward={total_reward:.2f}, Avg reward={avg_reward:.2f}",
284282
flush=True,
285283
)
286-
print(
287-
f"[OpenEnvVLLM] Returning {len(episode_prompt_ids)} episodes", flush=True
288-
)
284+
print(f"[OpenEnvVLLM] Returning {len(episode_prompt_ids)} episodes", flush=True)
289285
sys.stdout.flush()
290-
286+
291287
# Return in Wordle format
292288
# Tokens: 2D arrays (accumulate across turns, one list per episode)
293289
# Rewards: 1D arrays (one scalar per episode)
294290
total_rewards = [sum(r) for r in step_rewards_all] # Sum step rewards per episode
295-
291+
296292
print(f"[OpenEnvVLLM] Episode rewards: {total_rewards}", flush=True)
297-
293+
298294
return {
299295
"prompt_ids": episode_prompt_ids, # List[List[int]] - tokens per episode
300296
"completion_ids": episode_completion_ids, # List[List[int]] - tokens per episode
301297
"logprobs": episode_logprobs, # List[List[float]] - logprobs per episode
302298
"step_rewards": total_rewards, # List[float] - total reward per episode (1D!)
303299
}
304-
300+
305301
print(f"[openenv_trl_vllm] Returning rollout_func (type={type(rollout_func)})", flush=True)
306302
sys.stdout.flush()
307303
return rollout_func
308-

0 commit comments

Comments
 (0)