-
Notifications
You must be signed in to change notification settings - Fork 16
Tinker example #340
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Tinker example #340
Changes from all commits
Commits
Show all changes
3 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,197 @@ | ||
| import logging | ||
| import math | ||
| import asyncio | ||
| import inspect | ||
| from typing import Any, Callable, Literal, Optional, Sequence, List | ||
|
|
||
| try: | ||
| import chz | ||
| from tinker_cookbook import renderers, tokenizer_utils | ||
| from tinker_cookbook.rl.problem_env import ProblemGroupBuilder | ||
| from tinker_cookbook.rl.types import RLDataset, RLDatasetBuilder | ||
| from tinker_cookbook.eval.evaluators import SamplingClientEvaluator | ||
| import tinker | ||
|
|
||
| TINKER_AVAILABLE = True | ||
| except ImportError: | ||
| TINKER_AVAILABLE = False | ||
| # Dummy classes to avoid NameError when defining the class if imports fail | ||
| # but we should probably raise an error if these are instantiated without dependencies | ||
| RLDataset = object | ||
| RLDatasetBuilder = object | ||
| ProblemGroupBuilder = object | ||
| SamplingClientEvaluator = object | ||
|
|
||
| from eval_protocol.adapters.base import BaseAdapter | ||
| from eval_protocol.models import EvaluationRow | ||
| from eval_protocol.pytest.types import RolloutProcessorConfig | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
class EvalProtocolRLDataset(RLDataset):
    """Adapts an eval_protocol ``BaseAdapter`` into a tinker-cookbook ``RLDataset``.

    Rows are fetched eagerly from the adapter at construction time and served
    in fixed-size batches of ``ProblemGroupBuilder`` objects.
    """

    def __init__(
        self,
        adapter: BaseAdapter,
        row_converter: Callable[[Any, int], Optional[ProblemGroupBuilder]],
        batch_size: int,
        group_size: int,
        split: str = "train",
        limit: Optional[int] = None,
    ):
        """
        Args:
            adapter: Source of evaluation rows.
            row_converter: Maps ``(row, group_size)`` to a ``ProblemGroupBuilder``,
                or ``None`` to drop the row from its batch.
            batch_size: Number of rows per batch; must be positive.
            group_size: Rollout group size; forced to 1 for non-train splits.
            split: Dataset split to fetch (e.g. "train" or "test").
            limit: Maximum number of rows to fetch (``None`` fetches all).

        Raises:
            ImportError: If tinker-cookbook is not installed.
            ValueError: If ``batch_size`` is not positive.
        """
        if not TINKER_AVAILABLE:
            raise ImportError("tinker-cookbook is required to use EvalProtocolRLDataset")
        if batch_size <= 0:
            # Fail fast here; otherwise __len__ divides by zero much later.
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        self.adapter = adapter
        self.row_converter = row_converter
        self.batch_size = batch_size
        # Evaluation splits only need a single rollout per problem.
        self.group_size = group_size if split == "train" else 1

        logger.info(f"Fetching {limit if limit else 'all'} rows from adapter for split {split}...")
        self.rows = list(self.adapter.get_evaluation_rows(split=split, limit=limit))
        logger.info(f"Loaded {len(self.rows)} rows.")

    def get_batch(self, index: int) -> Sequence[ProblemGroupBuilder]:
        """Return the builders for batch *index*.

        Rows for which ``row_converter`` returns ``None`` are skipped, so a
        batch may contain fewer than ``batch_size`` builders.
        """
        batch_start = index * self.batch_size
        batch_end = min((index + 1) * self.batch_size, len(self.rows))

        batch_builders = []
        for i in range(batch_start, batch_end):
            row = self.rows[i]
            # row_converter should take the row and group_size and return a ProblemGroupBuilder
            builder = self.row_converter(row, self.group_size)
            if builder is not None:
                batch_builders.append(builder)

        return batch_builders

    def __len__(self) -> int:
        """Number of batches; a final partial batch counts as one."""
        return math.ceil(len(self.rows) / self.batch_size)
|
|
||
|
|
||
if TINKER_AVAILABLE:

    class EvalProtocolEvaluator(SamplingClientEvaluator):
        """Scores a fixed set of EvaluationRows against a tinker SamplingClient.

        Rollouts are produced via ``rollout_processor_cls``; each completed row
        is then scored with ``eval_func`` and the mean score is reported under
        the ``"accuracy"`` key.
        """

        def __init__(
            self,
            rows: List[EvaluationRow],
            eval_func: Callable[[EvaluationRow], EvaluationRow],
            rollout_processor_cls: Any,
            model_name: str,
            renderer_name: str,
            max_tokens: int = 512,
            temperature: float = 0.0,
        ):
            self.rows = rows

            # @evaluation_test wraps functions in a dual_mode_wrapper meant for
            # pytest execution; prefer the raw callable to skip that overhead.
            self.eval_func = getattr(eval_func, "_origin_func", eval_func)

            self.rollout_processor_cls = rollout_processor_cls
            self.model_name = model_name
            self.renderer_name = renderer_name
            self.max_tokens = max_tokens
            self.temperature = temperature

        async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:
            """Run rollouts for every row and return ``{"accuracy": mean score}``."""
            processor = self.rollout_processor_cls(
                sampling_client=sampling_client, model_name=self.model_name, renderer_name=self.renderer_name
            )
            processor.setup()

            rollout_config = RolloutProcessorConfig(
                completion_params={
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature,
                },
                semaphore=asyncio.Semaphore(10),  # Concurrency limit
                mcp_config_path="",  # Not used
                steps=1,
                logger=None,  # Optional logger
                kwargs={},
            )

            # Fan out all rollouts and wait for completion.
            completed_rows = await asyncio.gather(*processor(self.rows, rollout_config))

            collected_scores = []
            for finished_row in completed_rows:
                # eval_func may be sync or async; normalize via isawaitable.
                outcome = self.eval_func(finished_row)
                scored_row = await outcome if inspect.isawaitable(outcome) else outcome

                result = scored_row.evaluation_result
                if result and result.score is not None:
                    collected_scores.append(result.score)

            if not collected_scores:
                return {"accuracy": 0.0}
            return {"accuracy": sum(collected_scores) / len(collected_scores)}
|
|
||
|
|
||
def create_eval_protocol_dataset_builder(
    adapter_factory: Callable[[], BaseAdapter],
    row_converter: Callable[[Any, int, Any, Any], Optional[ProblemGroupBuilder]],
    convo_prefix_factory: Optional[Callable[[], list]] = None,
    train_limit: int = 1000,
    test_limit: int = 100,
) -> type:
    """
    Factory to create a specific RLDatasetBuilder class for a given adapter.

    The returned class is a chz-configurable builder whose ``__call__``
    constructs (train, test) EvalProtocolRLDataset instances wired to the
    given adapter and row converter. Returns ``object`` when tinker-cookbook
    is unavailable so module import still succeeds.
    """
    if not TINKER_AVAILABLE:
        return object

    @chz.chz
    class CustomBuilder(RLDatasetBuilder):
        batch_size: int
        model_name_for_tokenizer: str
        renderer_name: str
        group_size: int
        seed: int = 0

        async def __call__(self) -> tuple[RLDataset, RLDataset]:
            tokenizer = tokenizer_utils.get_tokenizer(self.model_name_for_tokenizer)
            renderer = renderers.get_renderer(self.renderer_name, tokenizer=tokenizer)

            adapter = adapter_factory()
            convo_prefix = convo_prefix_factory() if convo_prefix_factory else None

            def bound_row_converter(row, g_size):
                # Close over renderer and prefix so the dataset only has to
                # supply (row, group_size).
                return row_converter(row, g_size, renderer, convo_prefix)

            def build_split(split: str, limit: int) -> EvalProtocolRLDataset:
                # Both splits share everything except split name and limit.
                return EvalProtocolRLDataset(
                    adapter=adapter,
                    row_converter=bound_row_converter,
                    batch_size=self.batch_size,
                    group_size=self.group_size,
                    split=split,
                    limit=limit,
                )

            return (build_split("train", train_limit), build_split("test", test_limit))

    return CustomBuilder
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,170 @@ | ||
| import asyncio | ||
| import logging | ||
| import os | ||
| import time | ||
| import traceback | ||
| from typing import Any, Dict, List, Optional, Union | ||
|
|
||
| from eval_protocol.dataset_logger import default_logger | ||
| from eval_protocol.models import EvaluationRow, Message | ||
| from eval_protocol.pytest.rollout_processor import RolloutProcessor | ||
| from eval_protocol.pytest.types import RolloutProcessorConfig | ||
|
|
||
| try: | ||
| import tinker | ||
| from tinker_cookbook import renderers, tokenizer_utils | ||
|
|
||
| TINKER_AVAILABLE = True | ||
| except ImportError: | ||
| TINKER_AVAILABLE = False | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
class TinkerRolloutProcessor(RolloutProcessor):
    """
    Rollout processor that uses a Tinker SamplingClient to generate responses.
    """

    def __init__(
        self,
        sampling_client: Optional[Any] = None,
        model_name: Optional[str] = None,
        renderer_name: str = "llama3",
    ) -> None:
        """
        Args:
            sampling_client: Pre-initialized tinker.SamplingClient. If None, one will be created using model_name.
            model_name: Name of the model to use (if sampling_client is None).
            renderer_name: Name of the renderer to use for formatting messages.

        Raises:
            ImportError: If tinker-cookbook is not installed.
        """
        if not TINKER_AVAILABLE:
            raise ImportError("tinker-cookbook is required to use TinkerRolloutProcessor")

        self.sampling_client = sampling_client
        self.model_name = model_name
        self.renderer_name = renderer_name
        # Populated by setup(); callers must call setup() before __call__.
        self.renderer = None
        self.tokenizer = None

    def setup(self) -> None:
        """Create the sampling client (if needed) plus tokenizer and renderer.

        Raises:
            ValueError: If neither sampling_client nor model_name was given,
                or if model_name is missing (it is required for the tokenizer
                and renderer even when a sampling_client is supplied).
        """
        if self.sampling_client is None:
            if self.model_name is None:
                raise ValueError("Either sampling_client or model_name must be provided")

            # Initialize Tinker service client.
            # This assumes TINKER_API_KEY is set in env.
            service_client = tinker.ServiceClient()
            self.sampling_client = service_client.create_sampling_client(base_model=self.model_name)

        # The renderer must match the model's tokenizer, so model_name is
        # required even when a pre-built sampling_client was passed in.
        if not self.model_name:
            raise ValueError("model_name is required to initialize tokenizer/renderer")

        self.tokenizer = tokenizer_utils.get_tokenizer(self.model_name)
        self.renderer = renderers.get_renderer(self.renderer_name, tokenizer=self.tokenizer)

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Generate rollout tasks using Tinker.

        Returns one asyncio.Task per row; concurrency is bounded by
        ``config.semaphore``. Each task appends the sampled assistant message
        to its row (empty content on sampling failure) and logs the row.
        """

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            start_time = time.perf_counter()

            if not row.messages:
                raise ValueError("Messages is empty")

            # Render the conversation into a Tinker generation prompt.
            # Convert Message objects to plain dicts, since renderers key on
            # role/content and dicts are the safest common denominator.
            convo = [
                {"role": m.role, "content": m.content}
                for m in row.messages
                if m.role in ["system", "user", "assistant"]
            ]

            prompt = self.renderer.build_generation_prompt(convo)

            # Map config.completion_params onto Tinker SamplingParams, with
            # defaults matching standard configs.
            max_tokens = config.completion_params.get("max_tokens", 512)
            temperature = config.completion_params.get("temperature", 1.0)
            top_p = config.completion_params.get("top_p", 1.0)
            top_k = config.completion_params.get("top_k", -1)

            # The renderer knows the model's stop sequences; normalize None to [].
            stop_sequences = self.renderer.get_stop_sequences()
            if stop_sequences is None:
                stop_sequences = []

            sampling_params = tinker.SamplingParams(
                max_tokens=int(max_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
                top_k=int(top_k),
                stop=stop_sequences,
            )

            # Call Tinker API
            try:
                sample_result = await self.sampling_client.sample_async(
                    prompt=prompt, num_samples=1, sampling_params=sampling_params
                )

                # renderer.parse_response returns (Message, bool); the success
                # flag is intentionally unused — an unparseable response just
                # yields empty assistant content.
                sampled_tokens = sample_result.sequences[0].tokens
                message, _parse_success = self.renderer.parse_response(sampled_tokens)

                assistant_content = message["content"] if message else ""

            except Exception as e:
                # Some Tinker errors stringify to an unhelpful "0"; pull out
                # structured fields when available. getattr with defaults
                # avoids a second try/except that previously swallowed
                # AttributeError (and bound an unused name).
                error_details = str(e)
                if error_details == "0":
                    error_details = (
                        f"Code: {getattr(e, 'code', 'unknown')}, "
                        f"Message: {getattr(e, 'message', 'unknown')}"
                    )
                # Log full traceback for debugging
                tb_str = traceback.format_exc()
                logger.error(f"Tinker sampling failed: {error_details}\nTraceback:\n{tb_str}")
                assistant_content = ""  # Degrade gracefully: record an empty assistant turn.

            # Append the assistant response and record timing on the row.
            new_messages = list(row.messages) + [Message(role="assistant", content=assistant_content)]
            row.messages = new_messages
            row.execution_metadata.duration_seconds = time.perf_counter() - start_time

            # Tinker may not return usage stats in the expected format;
            # left as None until token accounting is implemented.
            row.execution_metadata.usage = None  # Placeholder

            default_logger.log(row)
            return row

        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            async with semaphore:
                return await process_row(r)

        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.