-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathtest_openenv_echo_hub.py
More file actions
120 lines (99 loc) · 3.91 KB
/
test_openenv_echo_hub.py
File metadata and controls
120 lines (99 loc) · 3.91 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
from typing import Any, Dict, List
import os
import re
from eval_protocol.models import EvaluationRow, Message, EvaluateResult
from eval_protocol.pytest import evaluation_test
from eval_protocol.pytest.openenv_rollout_processor import OpenEnvRolloutProcessor
import pytest
# Preferred import when using the monolithic `openenv` package.
try:
    from envs.echo_env import EchoEnv  # type: ignore
except ImportError:
    # Define dummy class to satisfy OpenEnvRolloutProcessor validation during collection.
    # The name must exist at import time because it is passed as `env_client_cls`
    # in the module-level `evaluation_test(...)` decorator below; the placeholder
    # has no behavior of its own.
    class EchoEnv:  # type: ignore
        pass
# Skip these integration-heavy tests on CI runners by default: the skip condition
# is true only when the CI environment variable is exactly the string "true".
pytestmark = pytest.mark.skipif(os.getenv("CI") == "true", reason="Skip OpenEnv integration tests on CI")
def echo_dataset_to_rows(data: List[Dict[str, Any]]) -> List[EvaluationRow]:
    """
    Adapter: turn simple {"id": "...", "prompt": "..."} records into EvaluationRows.

    Every record produces one row containing a single user message whose content
    is the stringified "prompt" value ("hello" when the key is absent). No other
    key of the record (including "id") is copied onto the row.
    """
    return [
        EvaluationRow(
            messages=[Message(role="user", content=str(record.get("prompt", "hello")))]
        )
        for record in data
    ]
def prompt_builder(observation: Any, step: int, history: List[str]) -> str:
    """
    Build the prompt for one Echo environment step.

    The Echo env is simple enough that `observation`, `step` and `history` are
    all ignored: the same short, fixed instruction is returned on every call.
    """
    instruction = "Please repeat back the next message exactly."
    return instruction
def action_parser(response_text: str):
    """
    Convert raw model response to EchoAction.

    The response is stripped of surrounding whitespace and wrapped in an
    ``EchoAction``. An empty/whitespace-only response, or a value that is not a
    ``str``, falls back to the message "hello".

    ``EchoAction`` is imported lazily so that this module can still be imported
    when the Echo environment package is missing; in that case the running test
    is skipped via ``pytest.skip``.
    """
    try:
        from envs.echo_env import EchoAction  # type: ignore
    except Exception:
        # pytest.skip() always raises, so nothing after it in this branch can run
        # (the original trailing `raise` was unreachable and has been dropped).
        pytest.skip("OpenEnv (openenv.envs.echo_env) is not installed; skipping Echo hub test.")
    text = response_text.strip() if isinstance(response_text, str) else ""
    return EchoAction(message=text or "hello")
# Inline test data: each record has an "id" and the "prompt" text that
# echo_dataset_to_rows() converts into a single user message.
ECHO_INLINE_DATA: List[Dict[str, Any]] = [
    {"id": "echo-1", "prompt": "hello"},
    {"id": "echo-2", "prompt": "test message"},
]
@evaluation_test(  # type: ignore[misc]
    input_rows=[echo_dataset_to_rows(ECHO_INLINE_DATA)],
    completion_params=[
        {
            "temperature": 0.0,
            "max_tokens": 16,
            # Any working model with your API key; match other tests' default
            "model": "fireworks_ai/accounts/fireworks/models/kimi-k2-instruct-0905",
        }
    ],
    num_runs=1,
    max_concurrent_rollouts=2,
    mode="pointwise",
    rollout_processor=(
        OpenEnvRolloutProcessor(
            # Use HF Hub to launch the environment container automatically
            env_client_cls=EchoEnv,  # type: ignore
            hub_repo_id=os.getenv("OPENENV_ECHO_REPO", "openenv/echo-env"),
            # Simple prompt+parser above
            prompt_builder=prompt_builder,
            action_parser=action_parser,
            # Keep defaults for timeouts/viewport/etc. (not relevant for echo)
            timeout_ms=5000,
            num_generations=1,
        )
    ),
)
def test_openenv_echo_hub(row: EvaluationRow) -> EvaluationRow:
    """
    Smoke test for Echo env via Hugging Face Hub (registry.hf.space/openenv-echo-env).

    Reads the per-step rewards stored under ``row.execution_metadata.extra["step_rewards"]``,
    sums them, clamps the total to [0.0, 1.0] and stores it on
    ``row.evaluation_result``. Missing metadata or an absent/empty reward list
    yields a total of 0.0.

    Reading the rewards is best-effort: any error (or a NaN total) scores 0.0,
    and the cause is appended to the result's ``reason`` so it is not mistaken
    for a genuine zero reward.

    Returns:
        The same ``row`` with ``evaluation_result`` populated.
    """
    import math  # local import: only needed for the NaN guard below

    total_reward = 0.0
    failure_note = ""
    try:
        extra = getattr(row.execution_metadata, "extra", None)
        raw_rewards = extra.get("step_rewards") if isinstance(extra, dict) else None
        step_rewards: List[float] = [float(r) for r in (raw_rewards or [])]
        print(f"Step rewards: {step_rewards}")
        total_reward = float(sum(step_rewards))
        # min(1.0, nan) evaluates to 1.0, so an unchecked NaN total would be
        # clamped to a perfect score; route it through the failure path instead.
        if math.isnan(total_reward):
            raise ValueError("step rewards sum to NaN")
    except Exception as exc:
        # Stay best-effort (never fail the smoke test here), but keep the cause visible.
        total_reward = 0.0
        failure_note = f" (step rewards unreadable: {exc!r})"
    score = max(0.0, min(1.0, total_reward))
    row.evaluation_result = EvaluateResult(
        score=score,
        reason=f"Echo total reward={total_reward:.2f}{failure_note}",
    )
    return row