Skip to content

Commit a1a973e

Browse files
author
Shrey Modi
committed
finalll
1 parent 70f3d0e commit a1a973e

File tree

3 files changed

+51
-53
lines changed

3 files changed

+51
-53
lines changed

eval_protocol/mcp/execution/vllm_policy.py

Lines changed: 39 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
"""
2-
VLLMPolicy - Policy for TRL's VLLMClient
2+
VLLMPolicy - Policy for TRL's VLLMClient or colocated vLLM LLM.
33
4-
Simple policy that calls TRL's vllm_client directly instead of going through LiteLLM.
5-
Works with `trl vllm-serve` endpoints.
4+
Thin adapter that turns Eval Protocol-style message lists into a single prompt,
5+
then calls either:
6+
7+
- TRL's VLLMClient (server mode), or
8+
- a colocated vLLM LLM instance (SamplingParams mode).
69
"""
710

11+
import logging
812
from typing import Any, Dict, List, Optional
913

1014

15+
logger = logging.getLogger(__name__)
16+
17+
1118
class VLLMPolicy:
1219
"""
1320
Policy that uses TRL's VLLMClient for generation.
@@ -52,7 +59,7 @@ async def _make_llm_call(
5259
tools: Optional[List] = None,
5360
) -> Dict[str, Any]:
5461
"""
55-
Make LLM call using TRL's VLLMClient.
62+
Make LLM call using TRL's VLLMClient or a colocated vLLM LLM.
5663
5764
Args:
5865
messages: List of message dicts with 'role' and 'content'
@@ -70,29 +77,29 @@ async def _make_llm_call(
7077
add_generation_prompt=True,
7178
tokenize=False,
7279
)
73-
print("\n[VLLMPolicy] ===== CHAT TEMPLATE APPLIED =====", flush=True)
74-
print(f"[VLLMPolicy] Input messages ({len(messages)} messages):", flush=True)
75-
for i, msg in enumerate(messages):
76-
content_preview = str(msg.get("content", ""))[:100]
77-
print(f" [{i}] {msg.get('role', '?')}: {content_preview}...", flush=True)
78-
print(f"[VLLMPolicy] Formatted prompt (length={len(prompt_text)}):", flush=True)
79-
print("[VLLMPolicy] Prompt preview (last 500 chars):", flush=True)
80-
print(f"{prompt_text[-500:]}", flush=True)
81-
print("[VLLMPolicy] ===================================", flush=True)
80+
logger.debug(
81+
"[VLLMPolicy] Chat template applied for %d messages (prompt length=%d)",
82+
len(messages),
83+
len(prompt_text),
84+
)
8285
except Exception as e:
83-
print(f"[VLLMPolicy] Warning: Failed to apply chat template: {e}", flush=True)
84-
# Fallback: simple concatenation
85-
prompt_text = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
86+
logger.warning(
87+
"[VLLMPolicy] Failed to apply chat template: %s",
88+
e,
89+
exc_info=True,
90+
)
91+
# Fallback: simple concatenation (defensive .get access)
92+
prompt_text = "\n".join(f"{m.get('role', '?')}: {m.get('content', '')}" for m in messages)
8693
else:
8794
# No tokenizer: simple concatenation
88-
prompt_text = "\n".join(f"{m['role']}: {m['content']}" for m in messages)
95+
prompt_text = "\n".join(f"{m.get('role', '?')}: {m.get('content', '')}" for m in messages)
8996

9097
# Check if vllm_client is VLLMClient (server mode) or LLM (colocate mode)
9198
is_llm_object = hasattr(self.vllm_client, "llm_engine") # LLM has llm_engine
9299

93100
if is_llm_object:
94101
# Colocate mode: use SamplingParams
95-
print("[VLLMPolicy] Using vLLM LLM (colocate mode) with SamplingParams", flush=True)
102+
logger.debug("[VLLMPolicy] Using vLLM LLM (colocate mode) with SamplingParams")
96103
from vllm import SamplingParams
97104

98105
sampling_params = SamplingParams(
@@ -103,7 +110,7 @@ async def _make_llm_call(
103110
n=1,
104111
)
105112

106-
print("[VLLMPolicy] Calling LLM.generate()...", flush=True)
113+
logger.debug("[VLLMPolicy] Calling LLM.generate()")
107114
outputs = self.vllm_client.generate([prompt_text], sampling_params=sampling_params, use_tqdm=False)
108115

109116
# Extract from vLLM output format
@@ -116,7 +123,7 @@ async def _make_llm_call(
116123
}
117124
else:
118125
# Server mode: use VLLMClient with kwargs
119-
print("[VLLMPolicy] Using VLLMClient (server mode)", flush=True)
126+
logger.debug("[VLLMPolicy] Using VLLMClient (server mode)")
120127
vllm_params = {
121128
"temperature": self.temperature,
122129
"max_tokens": self.max_tokens,
@@ -126,7 +133,7 @@ async def _make_llm_call(
126133
}
127134
vllm_params.update(self.kwargs)
128135

129-
print("[VLLMPolicy] Calling vllm_client.generate()...", flush=True)
136+
logger.debug("[VLLMPolicy] Calling vllm_client.generate()")
130137
response = self.vllm_client.generate(
131138
prompts=[prompt_text],
132139
**vllm_params,
@@ -140,16 +147,18 @@ async def _make_llm_call(
140147
if self.tokenizer is not None:
141148
try:
142149
completion_text = self.tokenizer.decode(completion_ids, skip_special_tokens=True)
143-
print("\n[VLLMPolicy] ===== GENERATION RESULT =====", flush=True)
144-
print(f"[VLLMPolicy] Prompt tokens: {len(prompt_ids)}", flush=True)
145-
print(f"[VLLMPolicy] Completion tokens: {len(completion_ids)}", flush=True)
146-
print(f"[VLLMPolicy] FULL decoded completion ({len(completion_text)} chars):", flush=True)
147-
print("───────────────────────────────────────", flush=True)
148-
print(f"{completion_text}", flush=True)
149-
print("───────────────────────────────────────", flush=True)
150-
print("[VLLMPolicy] ==============================", flush=True)
150+
logger.debug(
151+
"[VLLMPolicy] Generation result: prompt_tokens=%d, completion_tokens=%d, completion_chars=%d",
152+
len(prompt_ids),
153+
len(completion_ids),
154+
len(completion_text),
155+
)
151156
except Exception as e:
152-
print(f"[VLLMPolicy] Warning: Failed to decode completion: {e}", flush=True)
157+
logger.warning(
158+
"[VLLMPolicy] Failed to decode completion: %s",
159+
e,
160+
exc_info=True,
161+
)
153162
completion_text = f"<decoded_error:{len(completion_ids)}_tokens>"
154163
else:
155164
# Fallback: just indicate number of tokens

eval_protocol/pytest/integrations/openenv_trl_vllm.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -134,34 +134,18 @@ def rollout_func(prompts: List[str], trainer) -> Dict[str, List]:
134134
# 1) Build evaluation rows with rollout_id for logging
135135
import uuid
136136

137-
# Generate unique IDs for this batch
138-
def _gen_id():
139-
import random
140-
141-
words = [
142-
"quick",
143-
"lazy",
144-
"happy",
145-
"bright",
146-
"calm",
147-
"bold",
148-
"wise",
149-
"kind",
150-
]
151-
return f"{random.choice(words)}-{random.choice(words)}-{random.randint(10, 99)}"
152-
153137
evaluation_rows: List[EvaluationRow] = []
154138
for prompt_idx, prompt in enumerate(prompts):
155139
# One evaluation row per incoming prompt. GRPOTrainer will handle
156140
# grouping by `num_generations` at the trainer level; the custom
157141
# rollout_func must return one set of tokens per prompt.
158142
rollout_id = f"openenv_vllm_{uuid.uuid4().hex[:12]}"
159-
row_id = _gen_id()
160143

161144
row = EvaluationRow(
162145
messages=[Message(role="user", content=prompt)],
163146
input_metadata=InputMetadata(
164-
row_id=row_id, # Required for ep logs UI!
147+
# Let Eval Protocol generate a stable row_id from content.
148+
row_id=None,
165149
completion_params={},
166150
),
167151
)

eval_protocol/pytest/openenv_rollout_processor.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import asyncio
1616
import logging
1717
import time
18+
from itertools import count
1819
from typing import List, Any, Dict, Callable, Generic, TypeVar, Optional, Type
1920
import json
2021

@@ -142,7 +143,9 @@ def __init__(
142143
self._viewport_height = viewport_height
143144
self._timeout_ms = timeout_ms
144145
self._num_generations = max(1, int(num_generations)) if num_generations else 1
145-
self._env_create_idx: int = 0
146+
# Counter used for task rotation when creating environments. Uses
147+
# itertools.count to avoid race conditions across concurrent rollouts.
148+
self._env_create_counter = count()
146149

147150
if self._tasks and not self._task_var:
148151
raise ValueError("task_var must be provided when tasks are configured.")
@@ -411,7 +414,9 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
411414
)
412415
row.execution_metadata.duration_seconds = time.perf_counter() - start_time
413416

414-
# Store rewards for TRL reward functions
417+
# Store per-step rewards in a sentinel system message so
418+
# evaluation tests and downstream integrations can reconstruct
419+
# episode rewards.
415420
sentinel = "__ep_step_rewards__:" + json.dumps(step_rewards)
416421
messages.append(Message(role="system", content=sentinel))
417422

@@ -469,7 +474,6 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
469474
env.close()
470475
logger.debug("[OpenEnvRolloutProcessor] Environment closed successfully")
471476
except Exception as close_err:
472-
print(f"[OpenEnvRolloutProcessor] Warning: Error closing environment: {close_err}", flush=True)
473477
logger.warning(
474478
"[OpenEnvRolloutProcessor] Error closing environment: %s",
475479
close_err,
@@ -534,8 +538,9 @@ def _generic_factory():
534538
# Select task for this env instance (if provided), grouped by num_generations
535539
selected_task: Optional[str] = None
536540
if self._tasks:
537-
idx = self._env_create_idx
538-
self._env_create_idx = idx + 1
541+
# Use a monotonic counter so concurrent environment creation
542+
# does not reuse the same index across rollouts.
543+
idx = next(self._env_create_counter)
539544
group = idx // max(1, self._num_generations)
540545
selected_task = self._tasks[group % len(self._tasks)]
541546
if not self._task_var:

0 commit comments

Comments (0)