Skip to content

Commit 12fcae3

Browse files
committed
support for tokenids logprobs
1 parent c0c7c8e commit 12fcae3

File tree

2 files changed

+103
-0
lines changed

2 files changed

+103
-0
lines changed

eval_protocol/pytest/default_single_turn_rollout_process.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,53 @@ async def process_row(row: EvaluationRow) -> EvaluationRow:
170170
)
171171
)
172172

173+
# Synchronously extract token_ids, routing_matrix, and logprobs from the provider response.
174+
try:
175+
token_ids = []
176+
routing_matrix = []
177+
logprobs_obj = getattr(response.choices[0], "logprobs", None)
178+
179+
if logprobs_obj is not None:
180+
if isinstance(logprobs_obj, dict):
181+
content = logprobs_obj.get("content", [])
182+
else:
183+
content = getattr(logprobs_obj, "content", [])
184+
185+
if isinstance(content, list):
186+
for item in content:
187+
if isinstance(item, dict):
188+
tid = item.get("token_id")
189+
rm = item.get("routing_matrix")
190+
else:
191+
tid = getattr(item, "token_id", None)
192+
rm = getattr(item, "routing_matrix", None)
193+
194+
if tid is not None:
195+
token_ids.append(tid)
196+
if rm is not None:
197+
routing_matrix.append(rm)
198+
199+
logger.info(
200+
"[SingleTurnRolloutProcessor] Extracted %d token_ids and %d routing_matrix entries from logprobs",
201+
len(token_ids),
202+
len(routing_matrix),
203+
)
204+
205+
# Store as 1D lists directly for SingleTurn (no step dimension needed)
206+
if token_ids or routing_matrix or logprobs_obj is not None:
207+
if not row.execution_metadata.extra:
208+
row.execution_metadata.extra = {}
209+
if token_ids:
210+
row.execution_metadata.extra["token_ids"] = token_ids
211+
if routing_matrix:
212+
row.execution_metadata.extra["routing_matrix"] = routing_matrix
213+
if logprobs_obj is not None:
214+
row.execution_metadata.extra["logprobs"] = logprobs_obj
215+
except Exception as e:
216+
logger.warning(
217+
"[SingleTurnRolloutProcessor] Failed to extract token_ids/routing_matrix/logprobs: %s", e
218+
)
219+
173220
row.messages = messages
174221

175222
row.execution_metadata.duration_seconds = time.perf_counter() - start_time

tests/pytest/test_pytest_input_messages.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,59 @@ def test_input_messages_in_decorator(rows: List[EvaluationRow]) -> List[Evaluati
2222
for row in rows:
2323
row.evaluation_result = EvaluateResult(score=0.0, reason="Dummy evaluation result")
2424
return rows
25+
26+
27+
@pytest.mark.parametrize(
28+
"completion_params",
29+
[
30+
{
31+
"model": "fireworks_ai/accounts/fireworks/models/qwen3-30b-a3b",
32+
"logprobs": True,
33+
# "include_routing_matrix": True, # Requires --enable-moe-stats on server
34+
"temperature": 0.6,
35+
"max_tokens": 256,
36+
}
37+
],
38+
)
39+
@evaluation_test(
40+
input_messages=[
41+
[
42+
[
43+
Message(role="user", content="What is 2+2?"),
44+
]
45+
]
46+
],
47+
rollout_processor=SingleTurnRolloutProcessor(),
48+
mode="all",
49+
)
50+
def test_single_turn_with_logprobs_and_routing_matrix(rows: List[EvaluationRow]) -> List[EvaluationRow]:
51+
"""Test SingleTurnRolloutProcessor with logprobs and routing_matrix extraction."""
52+
for row in rows:
53+
# Check if extra metadata was extracted
54+
extra = row.execution_metadata.extra
55+
print("\n=== DEBUG: execution_metadata.extra ===")
56+
print(f"extra type: {type(extra)}")
57+
print(f"extra keys: {extra.keys() if isinstance(extra, dict) else 'N/A'}")
58+
59+
if isinstance(extra, dict):
60+
if "token_ids" in extra:
61+
token_ids = extra["token_ids"]
62+
print(f"token_ids: found, len={len(token_ids)}, first 10 ids={token_ids[:10]}")
63+
else:
64+
print("token_ids: NOT FOUND")
65+
66+
if "routing_matrix" in extra:
67+
routing_matrix = extra["routing_matrix"]
68+
print(f"routing_matrix: found, len={len(routing_matrix)}")
69+
else:
70+
print("routing_matrix: NOT FOUND")
71+
72+
if "logprobs" in extra:
73+
print("logprobs: found")
74+
else:
75+
print("logprobs: NOT FOUND")
76+
77+
print("=" * 50)
78+
79+
row.evaluation_result = EvaluateResult(score=1.0, reason="Test passed")
80+
return rows

0 commit comments

Comments
 (0)