Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 197 additions & 0 deletions eval_protocol/integrations/tinker_cookbook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
import logging
import math
import asyncio
import inspect
from typing import Any, Callable, Literal, Optional, Sequence, List

try:
import chz
from tinker_cookbook import renderers, tokenizer_utils
from tinker_cookbook.rl.problem_env import ProblemGroupBuilder
from tinker_cookbook.rl.types import RLDataset, RLDatasetBuilder
from tinker_cookbook.eval.evaluators import SamplingClientEvaluator
import tinker

TINKER_AVAILABLE = True
except ImportError:
TINKER_AVAILABLE = False
# Dummy classes to avoid NameError when defining the class if imports fail
# but we should probably raise an error if these are instantiated without dependencies
RLDataset = object
RLDatasetBuilder = object
ProblemGroupBuilder = object
SamplingClientEvaluator = object

from eval_protocol.adapters.base import BaseAdapter
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.types import RolloutProcessorConfig

logger = logging.getLogger(__name__)


class EvalProtocolRLDataset(RLDataset):
    """RL dataset backed by an eval-protocol adapter.

    Rows are fetched eagerly from ``adapter`` at construction time and served
    in fixed-size batches of ``ProblemGroupBuilder`` objects produced by
    ``row_converter``.
    """

    def __init__(
        self,
        adapter: BaseAdapter,
        row_converter: Callable[[Any, int], Optional[ProblemGroupBuilder]],
        batch_size: int,
        group_size: int,
        split: str = "train",
        limit: Optional[int] = None,
    ):
        """
        Args:
            adapter: Source of evaluation rows.
            row_converter: Maps (row, group_size) to a ProblemGroupBuilder,
                or None to skip the row.
            batch_size: Number of rows per batch.
            group_size: Rollouts per problem; forced to 1 for non-train splits.
            split: Dataset split to fetch ("train" or "test").
            limit: Maximum number of rows to fetch; None fetches all.

        Raises:
            ImportError: If the tinker-cookbook dependencies are unavailable.
        """
        if not TINKER_AVAILABLE:
            raise ImportError("tinker-cookbook is required to use EvalProtocolRLDataset")

        self.adapter = adapter
        self.row_converter = row_converter
        self.batch_size = batch_size
        # Evaluation splits only need a single rollout per problem.
        self.group_size = group_size if split == "train" else 1

        logger.info(f"Fetching {limit if limit else 'all'} rows from adapter for split {split}...")
        self.rows = list(self.adapter.get_evaluation_rows(split=split, limit=limit))
        logger.info(f"Loaded {len(self.rows)} rows.")

    def get_batch(self, index: int) -> Sequence[ProblemGroupBuilder]:
        """Return the builders for batch ``index``.

        Rows for which the converter returns None are skipped, so a batch may
        contain fewer than ``batch_size`` builders.
        """
        batch_start = index * self.batch_size
        batch_end = min((index + 1) * self.batch_size, len(self.rows))

        batch_builders = []
        for i in range(batch_start, batch_end):
            row = self.rows[i]
            # row_converter should take the row and group_size and return a ProblemGroupBuilder
            builder = self.row_converter(row, self.group_size)
            if builder is not None:
                batch_builders.append(builder)

        return batch_builders

    def __len__(self) -> int:
        """Number of batches; the final batch may be partial."""
        return math.ceil(len(self.rows) / self.batch_size)


if TINKER_AVAILABLE:

    class EvalProtocolEvaluator(SamplingClientEvaluator):
        """Evaluator that runs rollouts for a fixed set of EvaluationRows
        against a Tinker sampling client, scores them with ``eval_func``, and
        reports the mean score as ``{"accuracy": ...}``."""

        def __init__(
            self,
            rows: List[EvaluationRow],
            eval_func: Callable[[EvaluationRow], EvaluationRow],
            rollout_processor_cls: Any,
            model_name: str,
            renderer_name: str,
            max_tokens: int = 512,
            temperature: float = 0.0,
        ):
            """
            Args:
                rows: Rows to roll out and score on each call.
                eval_func: Scoring function (sync or async); may be a
                    dual_mode_wrapper produced by @evaluation_test.
                rollout_processor_cls: Processor class instantiated per call
                    with (sampling_client, model_name, renderer_name).
                model_name: Model identifier used by the processor.
                renderer_name: Renderer used to format prompts.
                max_tokens: Completion token budget per rollout.
                temperature: Sampling temperature (0.0 = greedy).
            """
            self.rows = rows

            # If the function is a dual_mode_wrapper (from @evaluation_test),
            # unwrap it to get the raw function logic. This avoids the
            # overhead of the wrapper which is designed for pytest execution.
            self.eval_func = getattr(eval_func, "_origin_func", eval_func)

            self.rollout_processor_cls = rollout_processor_cls
            self.model_name = model_name
            self.renderer_name = renderer_name
            self.max_tokens = max_tokens
            self.temperature = temperature

        async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:
            """Run rollouts, score the results, and return the mean score."""
            processor = self.rollout_processor_cls(
                sampling_client=sampling_client, model_name=self.model_name, renderer_name=self.renderer_name
            )
            processor.setup()

            # Config for rollout
            config = RolloutProcessorConfig(
                completion_params={
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature,
                },
                semaphore=asyncio.Semaphore(10),  # Concurrency limit
                mcp_config_path="",  # Not used
                steps=1,
                # NOTE(review): the field is typed as DatasetLogger, so None is
                # flagged by the type checker — confirm None is accepted at
                # runtime or pass a real DatasetLogger instance.
                logger=None,
                kwargs={},
            )

            # Run rollouts concurrently; the processor returns one task per row.
            tasks = processor(self.rows, config)
            processed_rows = await asyncio.gather(*tasks)

            # Score each completed row; eval_func may be sync or async.
            scores = []
            for row in processed_rows:
                res = self.eval_func(row)
                scored_row = await res if inspect.isawaitable(res) else res

                if scored_row.evaluation_result and scored_row.evaluation_result.score is not None:
                    scores.append(scored_row.evaluation_result.score)

            # Empty score list degrades to 0.0 rather than dividing by zero.
            mean_score = sum(scores) / len(scores) if scores else 0.0
            return {"accuracy": mean_score}


def create_eval_protocol_dataset_builder(
    adapter_factory: Callable[[], BaseAdapter],
    row_converter: Callable[[Any, int, Any, Any], Optional[ProblemGroupBuilder]],
    convo_prefix_factory: Optional[Callable[[], list]] = None,
    train_limit: int = 1000,
    test_limit: int = 100,
) -> type:
    """
    Factory to create a specific RLDatasetBuilder class for a given adapter.

    Args:
        adapter_factory: Zero-arg callable producing the BaseAdapter to read
            rows from.
        row_converter: Maps (row, group_size, renderer, convo_prefix) to a
            ProblemGroupBuilder, or None to skip the row.
        convo_prefix_factory: Optional zero-arg callable producing a
            conversation prefix passed through to row_converter.
        train_limit: Maximum rows fetched for the train split.
        test_limit: Maximum rows fetched for the test split.

    Returns:
        A chz-decorated RLDatasetBuilder subclass, or ``object`` when the
        tinker-cookbook dependencies are unavailable.
    """
    if not TINKER_AVAILABLE:
        return object

    @chz.chz
    class CustomBuilder(RLDatasetBuilder):
        batch_size: int
        model_name_for_tokenizer: str
        renderer_name: str
        group_size: int
        seed: int = 0

        async def __call__(self) -> tuple[RLDataset, RLDataset]:
            """Build and return (train_dataset, test_dataset)."""
            tokenizer = tokenizer_utils.get_tokenizer(self.model_name_for_tokenizer)
            renderer = renderers.get_renderer(self.renderer_name, tokenizer=tokenizer)

            # Create adapter
            adapter = adapter_factory()

            # Get convo prefix if needed
            convo_prefix = convo_prefix_factory() if convo_prefix_factory else None

            # Bind renderer and prefix so the dataset only needs the
            # (row, group_size) converter signature.
            def bound_row_converter(row, g_size):
                return row_converter(row, g_size, renderer, convo_prefix)

            train_ds = EvalProtocolRLDataset(
                adapter=adapter,
                row_converter=bound_row_converter,
                batch_size=self.batch_size,
                group_size=self.group_size,
                split="train",
                limit=train_limit,
            )

            test_ds = EvalProtocolRLDataset(
                adapter=adapter,
                row_converter=bound_row_converter,
                batch_size=self.batch_size,
                group_size=self.group_size,
                split="test",
                limit=test_limit,
            )

            return (train_ds, test_ds)

    return CustomBuilder
170 changes: 170 additions & 0 deletions eval_protocol/integrations/tinker_rollout_processor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
import asyncio
import logging
import os
import time
import traceback
from typing import Any, Dict, List, Optional, Union

from eval_protocol.dataset_logger import default_logger
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig

try:
import tinker
from tinker_cookbook import renderers, tokenizer_utils

TINKER_AVAILABLE = True
except ImportError:
TINKER_AVAILABLE = False

logger = logging.getLogger(__name__)


class TinkerRolloutProcessor(RolloutProcessor):
    """
    Rollout processor that uses a Tinker SamplingClient to generate responses.
    """

    def __init__(
        self,
        sampling_client: Optional[Any] = None,
        model_name: Optional[str] = None,
        renderer_name: str = "llama3",
    ) -> None:
        """
        Args:
            sampling_client: Pre-initialized tinker.SamplingClient. If None, one will be created using model_name.
            model_name: Name of the model to use (if sampling_client is None).
            renderer_name: Name of the renderer to use for formatting messages.

        Raises:
            ImportError: If the tinker-cookbook dependencies are unavailable.
        """
        if not TINKER_AVAILABLE:
            raise ImportError("tinker-cookbook is required to use TinkerRolloutProcessor")

        self.sampling_client = sampling_client
        self.model_name = model_name
        self.renderer_name = renderer_name
        # Tokenizer and renderer are initialized lazily in setup().
        self.renderer = None
        self.tokenizer = None

    def setup(self) -> None:
        """Create the sampling client (if needed) and the tokenizer/renderer.

        Raises:
            ValueError: If neither sampling_client nor model_name was given,
                or if model_name is missing (it is required to pick the
                tokenizer that matches the model).
        """
        if self.sampling_client is None:
            if self.model_name is None:
                raise ValueError("Either sampling_client or model_name must be provided")

            # Initialize Tinker service client.
            # This assumes TINKER_API_KEY is set in env.
            service_client = tinker.ServiceClient()
            self.sampling_client = service_client.create_sampling_client(base_model=self.model_name)

        # The renderer must use the tokenizer that matches the model, so
        # model_name is required even when a pre-built client was supplied.
        if self.model_name:
            self.tokenizer = tokenizer_utils.get_tokenizer(self.model_name)
        else:
            raise ValueError("model_name is required to initialize tokenizer/renderer")

        self.renderer = renderers.get_renderer(self.renderer_name, tokenizer=self.tokenizer)

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Generate rollout tasks using Tinker.

        Returns one asyncio.Task per row; each task samples a completion and
        appends it to the row's messages, bounded by ``config.semaphore``.
        """

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            start_time = time.perf_counter()

            if not row.messages:
                raise ValueError("Messages is empty")

            # Convert eval-protocol Message objects to plain role/content
            # dicts so the renderer can consume them regardless of whether it
            # uses key or attribute access.
            convo = [
                {"role": m.role, "content": m.content}
                for m in row.messages
                if m.role in ["system", "user", "assistant"]
            ]

            prompt = self.renderer.build_generation_prompt(convo)

            # Map config.completion_params onto Tinker SamplingParams, with
            # defaults matching standard configs.
            max_tokens = config.completion_params.get("max_tokens", 512)
            temperature = config.completion_params.get("temperature", 1.0)
            top_p = config.completion_params.get("top_p", 1.0)
            top_k = config.completion_params.get("top_k", -1)

            # Get stop sequences from renderer; normalize None to a list.
            stop_sequences = self.renderer.get_stop_sequences()
            if stop_sequences is None:
                stop_sequences = []

            sampling_params = tinker.SamplingParams(
                max_tokens=int(max_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
                top_k=int(top_k),
                stop=stop_sequences,
            )

            # Call Tinker API
            try:
                sample_result = await self.sampling_client.sample_async(
                    prompt=prompt, num_samples=1, sampling_params=sampling_params
                )

                # renderer.parse_response returns (Message, bool)
                sampled_tokens = sample_result.sequences[0].tokens
                message, parse_success = self.renderer.parse_response(sampled_tokens)

                assistant_content = message["content"] if message else ""

            except Exception as e:
                # Some errors stringify to an unhelpful "0"; try to surface
                # richer details from the exception's attributes instead.
                error_details = str(e)
                if error_details == "0":
                    try:
                        error_details = f"Code: {e.code}, Message: {getattr(e, 'message', 'unknown')}"
                    except Exception:
                        # Best effort only; fall back to the original string.
                        pass
                # Log full traceback for debugging
                tb_str = traceback.format_exc()
                logger.error(f"Tinker sampling failed: {error_details}\nTraceback:\n{tb_str}")
                assistant_content = ""  # Degrade gracefully to an empty reply.

            # Append the assistant turn and record timing.
            new_messages = list(row.messages) + [Message(role="assistant", content=assistant_content)]
            row.messages = new_messages
            row.execution_metadata.duration_seconds = time.perf_counter() - start_time

            # Tinker may not return usage stats in the same format; left
            # unset for now (could be counted from tokens later).
            row.execution_metadata.usage = None  # Placeholder

            default_logger.log(row)
            return row

        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            async with semaphore:
                return await process_row(r)

        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]
Loading
Loading