Skip to content

Commit 5a37648

Browse files
committed
2 parents 14bf073 + 936f4a5 commit 5a37648

21 files changed

+6732
-169
lines changed

eval_protocol/benchmarks/test_glm_streaming_compliance.py

Lines changed: 3551 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 197 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,197 @@
1+
import logging
2+
import math
3+
import asyncio
4+
import inspect
5+
from typing import Any, Callable, Literal, Optional, Sequence, List
6+
7+
try:
8+
import chz
9+
from tinker_cookbook import renderers, tokenizer_utils
10+
from tinker_cookbook.rl.problem_env import ProblemGroupBuilder
11+
from tinker_cookbook.rl.types import RLDataset, RLDatasetBuilder
12+
from tinker_cookbook.eval.evaluators import SamplingClientEvaluator
13+
import tinker
14+
15+
TINKER_AVAILABLE = True
16+
except ImportError:
17+
TINKER_AVAILABLE = False
18+
# Dummy classes to avoid NameError when defining the class if imports fail
19+
# but we should probably raise an error if these are instantiated without dependencies
20+
RLDataset = object
21+
RLDatasetBuilder = object
22+
ProblemGroupBuilder = object
23+
SamplingClientEvaluator = object
24+
25+
from eval_protocol.adapters.base import BaseAdapter
26+
from eval_protocol.models import EvaluationRow
27+
from eval_protocol.pytest.types import RolloutProcessorConfig
28+
29+
logger = logging.getLogger(__name__)
30+
31+
32+
class EvalProtocolRLDataset(RLDataset):
    """RLDataset backed by an eval-protocol adapter.

    All rows are fetched eagerly from the adapter at construction time and
    then served in fixed-size batches of ProblemGroupBuilder objects.
    """

    def __init__(
        self,
        adapter: BaseAdapter,
        row_converter: Callable[[Any, int], Optional[ProblemGroupBuilder]],
        batch_size: int,
        group_size: int,
        split: str = "train",
        limit: Optional[int] = None,
    ):
        if not TINKER_AVAILABLE:
            raise ImportError("tinker-cookbook is required to use EvalProtocolRLDataset")

        self.adapter = adapter
        self.row_converter = row_converter
        self.batch_size = batch_size
        # Only the training split uses multi-sample groups; any other split
        # (e.g. "test") samples once per row.
        self.group_size = group_size if split == "train" else 1

        logger.info(f"Fetching {limit if limit else 'all'} rows from adapter for split {split}...")
        self.rows = list(self.adapter.get_evaluation_rows(split=split, limit=limit))
        logger.info(f"Loaded {len(self.rows)} rows.")

    def get_batch(self, index: int) -> Sequence[ProblemGroupBuilder]:
        """Return the builders for batch `index`, skipping rows the converter rejects.

        row_converter takes (row, group_size) and may return None to drop a row.
        """
        start = index * self.batch_size
        chunk = self.rows[start : start + self.batch_size]
        converted = (self.row_converter(row, self.group_size) for row in chunk)
        return [builder for builder in converted if builder is not None]

    def __len__(self) -> int:
        """Number of batches; the final batch may be partial."""
        return math.ceil(len(self.rows) / self.batch_size)
70+
71+
72+
if TINKER_AVAILABLE:

    class EvalProtocolEvaluator(SamplingClientEvaluator):
        """Evaluator that scores eval-protocol rows sampled via a Tinker client.

        A rollout processor generates one completion per row against the
        provided sampling client, each completed row is scored with
        ``eval_func``, and the mean score is reported as ``"accuracy"``.
        """

        def __init__(
            self,
            rows: List[EvaluationRow],
            eval_func: Callable[[EvaluationRow], EvaluationRow],
            rollout_processor_cls: Any,
            model_name: str,
            renderer_name: str,
            max_tokens: int = 512,
            temperature: float = 0.0,
            max_concurrency: int = 10,
        ):
            """
            Args:
                rows: Evaluation rows to roll out and score.
                eval_func: Scoring function (sync or async) applied to each completed row.
                rollout_processor_cls: Rollout processor class; instantiated per call
                    with the sampling client, model name, and renderer name.
                model_name: Model identifier passed to the rollout processor.
                renderer_name: Renderer used to format prompts.
                max_tokens: Completion token budget per sample.
                temperature: Sampling temperature.
                max_concurrency: Upper bound on concurrent rollouts (previously a
                    hard-coded limit of 10; default preserves that behavior).
            """
            self.rows = rows

            # If the function is a dual_mode_wrapper (from @evaluation_test), unwrap it
            # to get the raw function logic. This avoids the overhead of the wrapper,
            # which is designed for pytest execution.
            if hasattr(eval_func, "_origin_func"):
                self.eval_func = eval_func._origin_func
            else:
                self.eval_func = eval_func

            self.rollout_processor_cls = rollout_processor_cls
            self.model_name = model_name
            self.renderer_name = renderer_name
            self.max_tokens = max_tokens
            self.temperature = temperature
            self.max_concurrency = max_concurrency

        async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:
            """Run rollouts against ``sampling_client`` and return {"accuracy": mean score}."""
            processor = self.rollout_processor_cls(
                sampling_client=sampling_client, model_name=self.model_name, renderer_name=self.renderer_name
            )
            processor.setup()

            # Config for rollout
            config = RolloutProcessorConfig(
                completion_params={
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature,
                },
                semaphore=asyncio.Semaphore(self.max_concurrency),  # Concurrency limit
                mcp_config_path="",  # Not used
                steps=1,
                logger=None,  # Optional logger
                kwargs={},
            )

            # Run rollouts
            tasks = processor(self.rows, config)
            processed_rows = await asyncio.gather(*tasks)

            # Score each completed row; eval_func may be sync or async.
            scores = []
            for row in processed_rows:
                res = self.eval_func(row)

                if inspect.isawaitable(res):
                    scored_row = await res
                else:
                    scored_row = res

                # Only rows that produced a numeric score contribute to the mean.
                if scored_row.evaluation_result and scored_row.evaluation_result.score is not None:
                    scores.append(scored_row.evaluation_result.score)

            # Empty score list (all rows unscored) reports 0.0 rather than dividing by zero.
            mean_score = sum(scores) / len(scores) if scores else 0.0
            return {"accuracy": mean_score}
139+
140+
141+
def create_eval_protocol_dataset_builder(
    adapter_factory: Callable[[], BaseAdapter],
    row_converter: Callable[[Any, int, Any, Any], Optional[ProblemGroupBuilder]],
    convo_prefix_factory: Optional[Callable[[], list]] = None,
    train_limit: int = 1000,
    test_limit: int = 100,
) -> type:
    """
    Factory to create a specific RLDatasetBuilder class for a given adapter.

    The returned chz-configurable class builds a (train, test) dataset pair,
    binding the renderer and optional conversation prefix into row_converter.
    Returns the bare ``object`` type when tinker-cookbook is not installed.
    """
    if not TINKER_AVAILABLE:
        return object

    @chz.chz
    class CustomBuilder(RLDatasetBuilder):
        batch_size: int
        model_name_for_tokenizer: str
        renderer_name: str
        group_size: int
        seed: int = 0

        async def __call__(self) -> tuple[RLDataset, RLDataset]:
            tokenizer = tokenizer_utils.get_tokenizer(self.model_name_for_tokenizer)
            renderer = renderers.get_renderer(self.renderer_name, tokenizer=tokenizer)

            adapter = adapter_factory()
            convo_prefix = convo_prefix_factory() if convo_prefix_factory else None

            def converter_with_context(row, group_size):
                # Inject the renderer and conversation prefix into the
                # user-supplied (row, group_size, renderer, prefix) converter.
                return row_converter(row, group_size, renderer, convo_prefix)

            def build_split(split: str, limit: int) -> EvalProtocolRLDataset:
                # One dataset per split, sharing the adapter and bound converter.
                return EvalProtocolRLDataset(
                    adapter=adapter,
                    row_converter=converter_with_context,
                    batch_size=self.batch_size,
                    group_size=self.group_size,
                    split=split,
                    limit=limit,
                )

            return (build_split("train", train_limit), build_split("test", test_limit))

    return CustomBuilder
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
import asyncio
2+
import logging
3+
import os
4+
import time
5+
import traceback
6+
from typing import Any, Dict, List, Optional, Union
7+
8+
from eval_protocol.dataset_logger import default_logger
9+
from eval_protocol.models import EvaluationRow, Message
10+
from eval_protocol.pytest.rollout_processor import RolloutProcessor
11+
from eval_protocol.pytest.types import RolloutProcessorConfig
12+
13+
try:
14+
import tinker
15+
from tinker_cookbook import renderers, tokenizer_utils
16+
17+
TINKER_AVAILABLE = True
18+
except ImportError:
19+
TINKER_AVAILABLE = False
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
class TinkerRolloutProcessor(RolloutProcessor):
    """
    Rollout processor that uses a Tinker SamplingClient to generate responses.

    For each input row the processor renders the conversation into a model
    prompt, samples a single completion from Tinker, appends it as an
    assistant message, and logs the updated row.
    """

    def __init__(
        self,
        sampling_client: Optional[Any] = None,
        model_name: Optional[str] = None,
        renderer_name: str = "llama3",
    ) -> None:
        """
        Args:
            sampling_client: Pre-initialized tinker.SamplingClient. If None, one will be created using model_name.
            model_name: Name of the model to use (if sampling_client is None). Also
                required in all cases to select the matching tokenizer/renderer.
            renderer_name: Name of the renderer to use for formatting messages.

        Raises:
            ImportError: if tinker-cookbook is not installed.
        """
        if not TINKER_AVAILABLE:
            raise ImportError("tinker-cookbook is required to use TinkerRolloutProcessor")

        self.sampling_client = sampling_client
        self.model_name = model_name
        self.renderer_name = renderer_name
        # Populated in setup(); None until then.
        self.renderer = None
        self.tokenizer = None

    def setup(self) -> None:
        """Create the sampling client (if needed) and the tokenizer/renderer pair.

        Raises:
            ValueError: if neither sampling_client nor model_name was provided,
                or if model_name is missing so no tokenizer can be selected.
        """
        if self.sampling_client is None:
            if self.model_name is None:
                raise ValueError("Either sampling_client or model_name must be provided")

            # Initialize Tinker service client.
            # This assumes TINKER_API_KEY is set in env.
            service_client = tinker.ServiceClient()
            self.sampling_client = service_client.create_sampling_client(base_model=self.model_name)

        # The renderer must match the model's tokenizer, so model_name is
        # required even when a pre-built sampling_client was supplied.
        if self.model_name:
            self.tokenizer = tokenizer_utils.get_tokenizer(self.model_name)
        else:
            raise ValueError("model_name is required to initialize tokenizer/renderer")

        self.renderer = renderers.get_renderer(self.renderer_name, tokenizer=self.tokenizer)

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Generate rollout tasks using Tinker.

        Returns one asyncio.Task per row; concurrency is bounded by
        config.semaphore. Each task mutates its row in place (appending the
        sampled assistant message and timing metadata) and returns it.
        """

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            start_time = time.perf_counter()

            if not row.messages:
                raise ValueError("Messages is empty")

            # Convert Message objects to plain dicts so the renderer can rely
            # on key access; roles outside the chat triple are dropped.
            convo = [
                {"role": m.role, "content": m.content}
                for m in row.messages
                if m.role in ["system", "user", "assistant"]
            ]

            prompt = self.renderer.build_generation_prompt(convo)

            # Map config.completion_params onto Tinker SamplingParams,
            # falling back to standard defaults.
            max_tokens = config.completion_params.get("max_tokens", 512)
            temperature = config.completion_params.get("temperature", 1.0)
            top_p = config.completion_params.get("top_p", 1.0)
            top_k = config.completion_params.get("top_k", -1)

            # Stop sequences come from the renderer; normalize None to a list.
            stop_sequences = self.renderer.get_stop_sequences()
            if stop_sequences is None:
                stop_sequences = []

            sampling_params = tinker.SamplingParams(
                max_tokens=int(max_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
                top_k=int(top_k),
                stop=stop_sequences,
            )

            # Call Tinker API
            try:
                sample_result = await self.sampling_client.sample_async(
                    prompt=prompt, num_samples=1, sampling_params=sampling_params
                )

                # renderer.parse_response returns (message, parse_success);
                # the success flag is not used here.
                sampled_tokens = sample_result.sequences[0].tokens
                message, _ = self.renderer.parse_response(sampled_tokens)

                assistant_content = message["content"] if message else ""

            except Exception as e:
                # Some Tinker errors stringify to an unhelpful "0"; surface the
                # structured code/message fields instead. getattr avoids a
                # secondary AttributeError when a field is absent (the previous
                # unguarded e.code access could fail and was silently swallowed).
                error_details = str(e)
                if error_details == "0":
                    error_details = (
                        f"Code: {getattr(e, 'code', 'unknown')}, Message: {getattr(e, 'message', 'unknown')}"
                    )
                # Log full traceback for debugging
                tb_str = traceback.format_exc()
                logger.error(f"Tinker sampling failed: {error_details}\nTraceback:\n{tb_str}")
                # Degrade gracefully: record an empty completion rather than
                # failing the whole rollout batch.
                assistant_content = ""

            # Append the sampled completion and record timing.
            new_messages = list(row.messages) + [Message(role="assistant", content=assistant_content)]
            row.messages = new_messages
            row.execution_metadata.duration_seconds = time.perf_counter() - start_time

            # Tinker does not return usage stats in the expected format,
            # so usage is left unset for now.
            row.execution_metadata.usage = None  # Placeholder

            default_logger.log(row)
            return row

        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            async with semaphore:
                return await process_row(r)

        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]

0 commit comments

Comments
 (0)