Skip to content

Commit b57ad2c

Browse files
committed
updates
1 parent 8638c2a commit b57ad2c

File tree

2 files changed: +20 -6 lines changed

2 files changed: +20 -6 lines changed

eval_protocol/pytest/integrations/openenv_trl_vllm.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -131,11 +131,15 @@ def rollout_func(prompts: List[str], trainer) -> Dict[str, List]:
131131
flush=True,
132132
)
133133

134+
# Import default logger for local tracing
135+
from eval_protocol.dataset_logger import default_logger
136+
134137
config = RolloutProcessorConfig(
135138
completion_params=base_params,
136139
mcp_config_path="",
137140
semaphore=asyncio.Semaphore(max_concurrency),
138141
steps=max_steps,
142+
logger=default_logger,
139143
)
140144

141145
# 3) Execute rollouts with VLLMPolicy

tests/pytest/test_openenv_echo_hub.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
import os
33
import re
44

5+
56
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
67
from eval_protocol.pytest import evaluation_test
78
from eval_protocol.pytest.openenv_rollout_processor import OpenEnvRolloutProcessor
89
import pytest
9-
import os
1010

1111
# Skip these integration-heavy tests on CI runners by default
1212
pytestmark = pytest.mark.skipif(os.getenv("CI") == "true", reason="Skip OpenEnv integration tests on CI")
@@ -45,14 +45,21 @@ def action_parser(response_text: str):
4545

4646
try:
4747
from envs.echo_env import EchoEnv # type: ignore
48+
4849
_HAS_ECHO = True
4950
except Exception:
5051
_HAS_ECHO = False
5152

5253

54+
# Inline test data
55+
ECHO_INLINE_DATA: List[Dict[str, Any]] = [
56+
{"id": "echo-1", "prompt": "hello"},
57+
{"id": "echo-2", "prompt": "test message"},
58+
]
59+
60+
5361
@evaluation_test( # type: ignore[misc]
54-
input_dataset=["tests/pytest/data/echo_dataset.jsonl"],
55-
dataset_adapter=echo_dataset_to_rows,
62+
input_rows=[echo_dataset_to_rows(ECHO_INLINE_DATA)],
5663
completion_params=[
5764
{
5865
"temperature": 0.0,
@@ -93,8 +100,13 @@ def test_openenv_echo_hub(row: EvaluationRow) -> EvaluationRow:
93100
# Preferred path: system sentinel "__ep_step_rewards__"
94101
step_rewards: List[float] = []
95102
for msg in row.messages or []:
96-
if msg.role == "system" and isinstance(msg.content, str) and msg.content.startswith("__ep_step_rewards__:"):
103+
if (
104+
msg.role == "system"
105+
and isinstance(msg.content, str)
106+
and msg.content.startswith("__ep_step_rewards__:")
107+
):
97108
import json as _json
109+
98110
payload = msg.content.split(":", 1)[1]
99111
step_rewards = _json.loads(payload) or []
100112
break
@@ -105,5 +117,3 @@ def test_openenv_echo_hub(row: EvaluationRow) -> EvaluationRow:
105117
score = max(0.0, min(1.0, total_reward))
106118
row.evaluation_result = EvaluateResult(score=score, reason=f"Echo total reward={total_reward:.2f}")
107119
return row
108-
109-

0 commit comments

Comments (0)