Commit 19e7a98

Tinker example

committed
1 parent ca8b2e8 commit 19e7a98

9 files changed: +799 -0 lines changed
Lines changed: 232 additions & 0 deletions
import logging
import math
import asyncio
from typing import Any, Callable, Optional, Sequence, List

try:
    import chz
    from tinker_cookbook import renderers, tokenizer_utils
    from tinker_cookbook.rl.problem_env import ProblemGroupBuilder
    from tinker_cookbook.rl.types import RLDataset, RLDatasetBuilder
    from tinker_cookbook.eval.evaluators import SamplingClientEvaluator
    import tinker

    TINKER_AVAILABLE = True
except ImportError:
    TINKER_AVAILABLE = False
    # Fallback bindings so the class definitions below do not raise
    # NameError when tinker-cookbook is missing; EvalProtocolRLDataset
    # raises ImportError in __init__ if instantiated without it.
    RLDataset = object
    RLDatasetBuilder = object
    ProblemGroupBuilder = object
    SamplingClientEvaluator = object

from eval_protocol.adapters.base import BaseAdapter
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.types import RolloutProcessorConfig

logger = logging.getLogger(__name__)

class EvalProtocolRLDataset(RLDataset):
    def __init__(
        self,
        adapter: BaseAdapter,
        row_converter: Callable[[Any, int], Optional[ProblemGroupBuilder]],
        batch_size: int,
        group_size: int,
        split: str = "train",
        limit: Optional[int] = None,
    ):
        if not TINKER_AVAILABLE:
            raise ImportError("tinker-cookbook is required to use EvalProtocolRLDataset")

        self.adapter = adapter
        self.row_converter = row_converter
        self.batch_size = batch_size
        self.group_size = group_size if split == "train" else 1

        logger.info(f"Fetching {limit if limit else 'all'} rows from adapter for split {split}...")
        self.rows = list(self.adapter.get_evaluation_rows(
            split=split,
            limit=limit,
        ))
        logger.info(f"Loaded {len(self.rows)} rows.")

    def get_batch(self, index: int) -> Sequence[ProblemGroupBuilder]:
        batch_start = index * self.batch_size
        batch_end = min((index + 1) * self.batch_size, len(self.rows))

        batch_builders = []
        for i in range(batch_start, batch_end):
            row = self.rows[i]
            # row_converter takes the row and group_size and returns a
            # ProblemGroupBuilder (or None to skip the row).
            builder = self.row_converter(row, self.group_size)
            if builder is not None:
                batch_builders.append(builder)

        return batch_builders

    def __len__(self) -> int:
        return math.ceil(len(self.rows) / self.batch_size)

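# --- Illustrative usage (not part of the original commit) ---
# A minimal sketch of the batch math above, assuming tinker-cookbook is
# installed. `_StubAdapter` and the pass-through converter are hypothetical
# stand-ins; a real converter returns a ProblemGroupBuilder per row.
def _example_dataset_usage() -> None:
    class _StubAdapter:
        # Duck-typed stand-in implementing only the method the dataset calls.
        def get_evaluation_rows(self, split: str = "train", limit: Optional[int] = None):
            return [{"id": i} for i in range(limit or 10)]

    ds = EvalProtocolRLDataset(
        adapter=_StubAdapter(),  # duck-typed; not a real BaseAdapter
        row_converter=lambda row, group_size: row,  # pass-through for the sketch
        batch_size=4,
        group_size=8,
        limit=10,
    )
    assert len(ds) == 3  # ceil(10 / 4) batches
    assert len(ds.get_batch(2)) == 2  # the last batch holds the remaining 2 rows
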

if TINKER_AVAILABLE:

    class EvalProtocolEvaluator(SamplingClientEvaluator):
        """
        Evaluator that uses Eval Protocol's logic to evaluate a model.

        The SamplingClientEvaluator interface only supplies a sampling
        client, but the rollout processor also needs a model name to
        initialize its tokenizer, so the model name is taken here at
        construction time.
        """

        def __init__(
            self,
            rows: List[EvaluationRow],
            scoring_fn: Callable[[EvaluationRow], EvaluationRow],
            rollout_processor_cls: Any,  # e.g. TinkerRolloutProcessor
            model_name: str,
            renderer_name: str,
            max_tokens: int = 512,
            temperature: float = 0.0,
        ):
            self.rows = rows
            self.scoring_fn = scoring_fn
            self.rollout_processor_cls = rollout_processor_cls
            self.model_name = model_name
            self.renderer_name = renderer_name
            self.max_tokens = max_tokens
            self.temperature = temperature

        async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:
            # Create a processor bound to the current sampling client
            processor = self.rollout_processor_cls(
                sampling_client=sampling_client,
                model_name=self.model_name,
                renderer_name=self.renderer_name,
            )
            processor.setup()

            # Config for rollout
            config = RolloutProcessorConfig(
                completion_params={
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature,
                },
                semaphore=asyncio.Semaphore(10),  # concurrency limit
                mcp_config_path="",  # not used
                steps=1,
                logger=None,  # optional logger
                kwargs={},
            )

            # Run rollouts
            tasks = processor(self.rows, config)
            processed_rows = await asyncio.gather(*tasks)

            # Score each rollout; skip rows without a numeric score
            scores = []
            for row in processed_rows:
                scored_row = self.scoring_fn(row)
                if scored_row.evaluation_result and scored_row.evaluation_result.score is not None:
                    scores.append(scored_row.evaluation_result.score)

            mean_score = sum(scores) / len(scores) if scores else 0.0
            return {"accuracy": mean_score}

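# --- Illustrative usage (not part of the original commit) ---
# A hedged sketch of one eval pass. `TinkerRolloutProcessor`, `eval_rows`,
# and `my_scoring_fn` are assumed to exist at the call site; they are not
# defined in this module.
#
# async def run_eval(sampling_client: tinker.SamplingClient) -> dict[str, float]:
#     evaluator = EvalProtocolEvaluator(
#         rows=eval_rows,
#         scoring_fn=my_scoring_fn,
#         rollout_processor_cls=TinkerRolloutProcessor,
#         model_name="meta-llama/Llama-3.1-8B-Instruct",
#         renderer_name="llama3",
#     )
#     return await evaluator(sampling_client)  # e.g. {"accuracy": ...}
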

def create_eval_protocol_dataset_builder(
    adapter_factory: Callable[[], BaseAdapter],
    row_converter: Callable[[Any, int, Any, Any], Optional[ProblemGroupBuilder]],
    convo_prefix_factory: Optional[Callable[[], list]] = None,
    train_limit: int = 1000,
    test_limit: int = 100,
) -> type:
    """
    Factory that creates an RLDatasetBuilder class specialized for a given adapter.
    """
    if not TINKER_AVAILABLE:
        return object

    @chz.chz
    class CustomBuilder(RLDatasetBuilder):
        batch_size: int
        model_name_for_tokenizer: str
        renderer_name: str
        group_size: int
        seed: int = 0

        async def __call__(self) -> tuple[RLDataset, RLDataset]:
            tokenizer = tokenizer_utils.get_tokenizer(self.model_name_for_tokenizer)
            renderer = renderers.get_renderer(self.renderer_name, tokenizer=tokenizer)

            # Create the adapter
            adapter = adapter_factory()

            # Get the conversation prefix if one is configured
            convo_prefix = convo_prefix_factory() if convo_prefix_factory else None

            # Wrap the row_converter to inject the renderer and convo prefix
            def bound_row_converter(row, g_size):
                return row_converter(row, g_size, renderer, convo_prefix)

            train_ds = EvalProtocolRLDataset(
                adapter=adapter,
                row_converter=bound_row_converter,
                batch_size=self.batch_size,
                group_size=self.group_size,
                split="train",
                limit=train_limit,
            )

            test_ds = EvalProtocolRLDataset(
                adapter=adapter,
                row_converter=bound_row_converter,
                batch_size=self.batch_size,
                group_size=self.group_size,
                split="test",
                limit=test_limit,
            )

            return (train_ds, test_ds)

    return CustomBuilder
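
# --- Illustrative usage (not part of the original commit) ---
# A sketch of building train/test datasets with the factory. The names
# `my_adapter_factory` and `my_row_converter` are assumptions, and keyword
# construction of the generated @chz.chz builder is assumed to work as it
# does for other chz classes.
#
# Builder = create_eval_protocol_dataset_builder(
#     adapter_factory=my_adapter_factory,
#     row_converter=my_row_converter,  # (row, group_size, renderer, convo_prefix) -> Optional[ProblemGroupBuilder]
#     train_limit=1000,
#     test_limit=100,
# )
# builder = Builder(
#     batch_size=8,
#     model_name_for_tokenizer="meta-llama/Llama-3.1-8B-Instruct",
#     renderer_name="llama3",
#     group_size=4,
# )
# train_ds, test_ds = await builder()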
