-
Notifications
You must be signed in to change notification settings - Fork 16
Expand file tree
/
Copy pathtinker_cookbook.py
More file actions
197 lines (162 loc) · 6.92 KB
/
tinker_cookbook.py
File metadata and controls
197 lines (162 loc) · 6.92 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
import logging
import math
import asyncio
import inspect
from typing import Any, Callable, Literal, Optional, Sequence, List
try:
import chz
from tinker_cookbook import renderers, tokenizer_utils
from tinker_cookbook.rl.problem_env import ProblemGroupBuilder
from tinker_cookbook.rl.types import RLDataset, RLDatasetBuilder
from tinker_cookbook.eval.evaluators import SamplingClientEvaluator
import tinker
TINKER_AVAILABLE = True
except ImportError:
TINKER_AVAILABLE = False
# Dummy classes to avoid NameError when defining the class if imports fail
# but we should probably raise an error if these are instantiated without dependencies
RLDataset = object
RLDatasetBuilder = object
ProblemGroupBuilder = object
SamplingClientEvaluator = object
from eval_protocol.adapters.base import BaseAdapter
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.types import RolloutProcessorConfig
logger = logging.getLogger(__name__)
class EvalProtocolRLDataset(RLDataset):
    """RL dataset that serves batches of ProblemGroupBuilders built from
    eval-protocol adapter rows.

    Rows are fetched eagerly from the adapter at construction time; they are
    converted to ProblemGroupBuilders lazily, one batch at a time, through
    ``row_converter``.
    """

    def __init__(
        self,
        adapter: BaseAdapter,
        row_converter: Callable[[Any, int], Optional[ProblemGroupBuilder]],
        batch_size: int,
        group_size: int,
        split: str = "train",
        limit: Optional[int] = None,
    ):
        if not TINKER_AVAILABLE:
            raise ImportError("tinker-cookbook is required to use EvalProtocolRLDataset")
        self.adapter = adapter
        self.row_converter = row_converter
        self.batch_size = batch_size
        # Non-train splits only need a single rollout per problem.
        self.group_size = group_size if split == "train" else 1
        logger.info(f"Fetching {limit if limit else 'all'} rows from adapter for split {split}...")
        self.rows = list(self.adapter.get_evaluation_rows(split=split, limit=limit))
        logger.info(f"Loaded {len(self.rows)} rows.")

    def get_batch(self, index: int) -> Sequence[ProblemGroupBuilder]:
        """Return the ProblemGroupBuilders for batch ``index``.

        Rows whose converter returns None are skipped, so a batch may hold
        fewer than ``batch_size`` builders.
        """
        lo = index * self.batch_size
        hi = min(lo + self.batch_size, len(self.rows))
        # row_converter takes (row, group_size) and may decline a row by
        # returning None.
        candidates = (self.row_converter(row, self.group_size) for row in self.rows[lo:hi])
        return [builder for builder in candidates if builder is not None]

    def __len__(self) -> int:
        """Number of batches; a trailing partial batch counts as one."""
        return math.ceil(len(self.rows) / self.batch_size)
if TINKER_AVAILABLE:
    class EvalProtocolEvaluator(SamplingClientEvaluator):
        """Evaluator that runs eval-protocol rollouts against a tinker
        sampling client and reports the mean score under ``accuracy``.
        """

        def __init__(
            self,
            rows: List[EvaluationRow],
            eval_func: Callable[[EvaluationRow], EvaluationRow],
            rollout_processor_cls: Any,
            model_name: str,
            renderer_name: str,
            max_tokens: int = 512,
            temperature: float = 0.0,
        ):
            self.rows = rows
            # @evaluation_test wraps the function for pytest execution; when the
            # wrapper exposes the raw logic via _origin_func, use that directly
            # to skip the wrapper overhead.
            self.eval_func = getattr(eval_func, "_origin_func", eval_func)
            self.rollout_processor_cls = rollout_processor_cls
            self.model_name = model_name
            self.renderer_name = renderer_name
            self.max_tokens = max_tokens
            self.temperature = temperature

        async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:
            """Roll out every row, score it with ``eval_func``, and return
            ``{"accuracy": mean score}`` (0.0 when nothing was scored)."""
            processor = self.rollout_processor_cls(
                sampling_client=sampling_client, model_name=self.model_name, renderer_name=self.renderer_name
            )
            processor.setup()
            rollout_config = RolloutProcessorConfig(
                completion_params={
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature,
                },
                semaphore=asyncio.Semaphore(10),  # cap concurrent rollouts
                mcp_config_path="",  # not used by this evaluator
                steps=1,
                logger=None,  # no per-rollout logger
                kwargs={},
            )
            # The processor yields one awaitable per row; run them concurrently.
            processed_rows = await asyncio.gather(*processor(self.rows, rollout_config))
            scores = []
            for processed_row in processed_rows:
                # eval_func may be either sync or async.
                outcome = self.eval_func(processed_row)
                if inspect.isawaitable(outcome):
                    outcome = await outcome
                result = outcome.evaluation_result
                if result and result.score is not None:
                    scores.append(result.score)
            mean_score = sum(scores) / len(scores) if scores else 0.0
            return {"accuracy": mean_score}
def create_eval_protocol_dataset_builder(
    adapter_factory: Callable[[], BaseAdapter],
    row_converter: Callable[[Any, int, Any, Any], Optional[ProblemGroupBuilder]],
    convo_prefix_factory: Optional[Callable[[], list]] = None,
    train_limit: int = 1000,
    test_limit: int = 100,
) -> type:
    """
    Factory to create a specific RLDatasetBuilder class for a given adapter.

    The returned class is a chz-configurable RLDatasetBuilder whose
    ``__call__`` produces a (train, test) pair of EvalProtocolRLDatasets.
    """
    if not TINKER_AVAILABLE:
        # Degrade gracefully, matching the module's import-fallback pattern.
        return object

    @chz.chz
    class CustomBuilder(RLDatasetBuilder):
        batch_size: int
        model_name_for_tokenizer: str
        renderer_name: str
        group_size: int
        seed: int = 0

        async def __call__(self) -> tuple[RLDataset, RLDataset]:
            """Build and return the (train, test) dataset pair."""
            tokenizer = tokenizer_utils.get_tokenizer(self.model_name_for_tokenizer)
            renderer = renderers.get_renderer(self.renderer_name, tokenizer=tokenizer)
            adapter = adapter_factory()
            convo_prefix = convo_prefix_factory() if convo_prefix_factory else None

            # Close over renderer/prefix so the dataset sees the simple
            # (row, group_size) converter signature it expects.
            def bound_row_converter(row, g_size):
                return row_converter(row, g_size, renderer, convo_prefix)

            def make_split(split: str, limit: int) -> RLDataset:
                # One dataset per split; both share the adapter and converter.
                return EvalProtocolRLDataset(
                    adapter=adapter,
                    row_converter=bound_row_converter,
                    batch_size=self.batch_size,
                    group_size=self.group_size,
                    split=split,
                    limit=limit,
                )

            return (make_split("train", train_limit), make_split("test", test_limit))

    return CustomBuilder