# ---- diff header (preserved from the original patch) ----
# diff --git a/eval_protocol/integrations/tinker_cookbook.py b/eval_protocol/integrations/tinker_cookbook.py
# new file mode 100644
# index 00000000..ed2e3674
# --- /dev/null
# +++ b/eval_protocol/integrations/tinker_cookbook.py
# @@ -0,0 +1,197 @@
import logging
import math
import asyncio
import copy
import inspect
from typing import Any, Callable, Literal, Optional, Sequence, List

try:
    import chz
    from tinker_cookbook import renderers, tokenizer_utils
    from tinker_cookbook.rl.problem_env import ProblemGroupBuilder
    from tinker_cookbook.rl.types import RLDataset, RLDatasetBuilder
    from tinker_cookbook.eval.evaluators import SamplingClientEvaluator
    import tinker

    TINKER_AVAILABLE = True
except ImportError:
    TINKER_AVAILABLE = False
    # Placeholder base classes so the class statements below do not raise
    # NameError when the optional dependencies are missing. Instantiating
    # EvalProtocolRLDataset without them raises ImportError in __init__.
    RLDataset = object
    RLDatasetBuilder = object
    ProblemGroupBuilder = object
    SamplingClientEvaluator = object

from eval_protocol.adapters.base import BaseAdapter
from eval_protocol.models import EvaluationRow
from eval_protocol.pytest.types import RolloutProcessorConfig

logger = logging.getLogger(__name__)


class EvalProtocolRLDataset(RLDataset):
    """Batched RL dataset built from the rows of an eval_protocol adapter.

    All rows for ``split`` are fetched eagerly in ``__init__``. Each call to
    ``get_batch`` converts the rows of one batch through ``row_converter``.
    """

    def __init__(
        self,
        adapter: BaseAdapter,
        row_converter: Callable[[Any, int], Optional[ProblemGroupBuilder]],
        batch_size: int,
        group_size: int,
        split: str = "train",
        limit: Optional[int] = None,
    ):
        """
        Args:
            adapter: Row source; ``adapter.get_evaluation_rows(split=..., limit=...)`` is called once.
            row_converter: Called as ``row_converter(row, group_size)``; returns a builder,
                or ``None`` to drop the row from its batch.
            batch_size: Number of rows per batch. Must be positive.
            group_size: Group size handed to ``row_converter`` for the ``"train"`` split;
                every other split uses a group size of 1.
            split: Name of the split requested from the adapter.
            limit: Row limit forwarded to the adapter (``None`` is forwarded unchanged).

        Raises:
            ImportError: If the tinker-cookbook dependencies could not be imported.
            ValueError: If ``batch_size`` is not positive.
        """
        if not TINKER_AVAILABLE:
            raise ImportError("tinker-cookbook is required to use EvalProtocolRLDataset")
        # Fail fast: a zero batch size would otherwise surface later as a
        # ZeroDivisionError inside __len__.
        if batch_size <= 0:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")

        self.adapter = adapter
        self.row_converter = row_converter
        self.batch_size = batch_size
        self.group_size = group_size if split == "train" else 1

        logger.info("Fetching %s rows from adapter for split %s...", limit if limit else "all", split)
        self.rows = list(self.adapter.get_evaluation_rows(split=split, limit=limit))
        logger.info("Loaded %d rows.", len(self.rows))

    def get_batch(self, index: int) -> Sequence[ProblemGroupBuilder]:
        """Return the builders for batch ``index``.

        Rows for which ``row_converter`` returns ``None`` are skipped, so a batch
        may hold fewer than ``batch_size`` builders. An index past the last batch
        yields an empty list.

        Raises:
            IndexError: If ``index`` is negative.
        """
        if index < 0:
            raise IndexError(f"batch index must be non-negative, got {index}")

        batch_start = index * self.batch_size
        # Slicing clamps at the end of the list, so the final partial batch needs no special case.
        batch_rows = self.rows[batch_start : batch_start + self.batch_size]
        converted = (self.row_converter(row, self.group_size) for row in batch_rows)
        return [builder for builder in converted if builder is not None]

    def __len__(self) -> int:
        """Number of batches, counting a trailing partial batch."""
        return math.ceil(len(self.rows) / self.batch_size)


if TINKER_AVAILABLE:

    class EvalProtocolEvaluator(SamplingClientEvaluator):
        """Evaluator that runs rollouts for a fixed set of rows and reports the mean score."""

        def __init__(
            self,
            rows: List[EvaluationRow],
            eval_func: Callable[[EvaluationRow], EvaluationRow],
            rollout_processor_cls: Any,
            model_name: str,
            renderer_name: str,
            max_tokens: int = 512,
            temperature: float = 0.0,
            max_concurrency: int = 10,
        ):
            """
            Args:
                rows: Input rows; they are copied on every evaluation and never mutated.
                eval_func: Scoring function (sync or async). If it carries an
                    ``_origin_func`` attribute, that underlying function is used instead.
                rollout_processor_cls: Callable constructed with ``sampling_client``,
                    ``model_name`` and ``renderer_name`` keyword arguments.
                model_name: Forwarded to the rollout processor.
                renderer_name: Forwarded to the rollout processor.
                max_tokens: ``max_tokens`` completion parameter for rollouts.
                temperature: ``temperature`` completion parameter for rollouts.
                max_concurrency: Size of the semaphore bounding concurrent rollouts.
            """
            self.rows = rows

            # If the function is a dual_mode_wrapper (from @evaluation_test), unwrap it to get the
            # raw function logic. This avoids the overhead of the wrapper, which is designed for
            # pytest execution.
            self.eval_func = getattr(eval_func, "_origin_func", eval_func)

            self.rollout_processor_cls = rollout_processor_cls
            self.model_name = model_name
            self.renderer_name = renderer_name
            self.max_tokens = max_tokens
            self.temperature = temperature
            self.max_concurrency = max_concurrency

        async def __call__(self, sampling_client: tinker.SamplingClient) -> dict[str, float]:
            """Run rollouts with ``sampling_client``, score them, and return ``{"accuracy": mean}``.

            Rows whose result is missing or has a ``None`` score are excluded from the
            mean; with no scored rows the accuracy is 0.0.
            """
            processor = self.rollout_processor_cls(
                sampling_client=sampling_client, model_name=self.model_name, renderer_name=self.renderer_name
            )
            processor.setup()

            # Config for rollout
            config = RolloutProcessorConfig(
                completion_params={
                    "max_tokens": self.max_tokens,
                    "temperature": self.temperature,
                },
                semaphore=asyncio.Semaphore(self.max_concurrency),  # Concurrency limit
                mcp_config_path="",  # Not used
                steps=1,
                logger=None,  # Optional logger
                kwargs={},
            )

            # Rollout processors may modify rows in place (e.g. by appending the
            # generated assistant message). Work on a fresh copy so that repeated
            # evaluations always start from the original prompts.
            rows = copy.deepcopy(self.rows)

            # Run rollouts
            tasks = processor(rows, config)
            processed_rows = await asyncio.gather(*tasks)

            # Score
            scores = []
            for row in processed_rows:
                # Call the function logic (sync or async)
                result = self.eval_func(row)
                scored_row = await result if inspect.isawaitable(result) else result

                if scored_row.evaluation_result and scored_row.evaluation_result.score is not None:
                    scores.append(scored_row.evaluation_result.score)

            mean_score = sum(scores) / len(scores) if scores else 0.0
            return {"accuracy": mean_score}


def create_eval_protocol_dataset_builder(
    adapter_factory: Callable[[], BaseAdapter],
    row_converter: Callable[[Any, int, Any, Any], Optional[ProblemGroupBuilder]],
    convo_prefix_factory: Optional[Callable[[], list]] = None,
    train_limit: int = 1000,
    test_limit: int = 100,
) -> type:
    """
    Factory to create a specific RLDatasetBuilder class for a given adapter.

    Args:
        adapter_factory: Zero-argument callable returning the adapter; called once per build.
        row_converter: Called as ``row_converter(row, group_size, renderer, convo_prefix)``.
        convo_prefix_factory: Optional zero-argument callable producing the conversation
            prefix; when omitted, ``None`` is passed as the prefix.
        train_limit: Row limit for the ``"train"`` split.
        test_limit: Row limit for the ``"test"`` split.

    Returns:
        A ``chz`` builder class whose async ``__call__`` returns ``(train_ds, test_ds)``.
        If the tinker-cookbook dependencies are missing, ``object`` is returned instead.
    """
    if not TINKER_AVAILABLE:
        # Kept for backward compatibility, but make the degraded mode visible.
        logger.warning("tinker-cookbook is not installed; create_eval_protocol_dataset_builder returns `object`")
        return object

    @chz.chz
    class CustomBuilder(RLDatasetBuilder):
        batch_size: int
        model_name_for_tokenizer: str
        renderer_name: str
        group_size: int
        seed: int = 0

        async def __call__(self) -> tuple[RLDataset, RLDataset]:
            tokenizer = tokenizer_utils.get_tokenizer(self.model_name_for_tokenizer)
            renderer = renderers.get_renderer(self.renderer_name, tokenizer=tokenizer)

            # Create adapter
            adapter = adapter_factory()

            # Get convo prefix if needed
            convo_prefix = convo_prefix_factory() if convo_prefix_factory else None

            # Wrap row_converter so the dataset can call it with (row, group_size)
            # while the renderer and prefix are injected here.
            def bound_row_converter(row, g_size):
                return row_converter(row, g_size, renderer, convo_prefix)

            def build_split(split: str, limit: int) -> EvalProtocolRLDataset:
                return EvalProtocolRLDataset(
                    adapter=adapter,
                    row_converter=bound_row_converter,
                    batch_size=self.batch_size,
                    group_size=self.group_size,
                    split=split,
                    limit=limit,
                )

            train_ds = build_split("train", train_limit)
            test_ds = build_split("test", test_limit)

            return (train_ds, test_ds)

    return CustomBuilder


# ---- start of the next file in this diff (shares the original source line; preserved) ----
# diff --git a/eval_protocol/integrations/tinker_rollout_processor.py b/eval_protocol/integrations/tinker_rollout_processor.py
# new file mode 100644
# index 00000000..5edf4cf6
# --- /dev/null
# +++ b/eval_protocol/integrations/tinker_rollout_processor.py
# @@ -0,0 +1,170 @@
# +import asyncio
# +import logging
# +import os
# +import time
# +import traceback
# +from typing import Any, Dict, List, Optional, Union
# +
# +from eval_protocol.dataset_logger import default_logger
# +from eval_protocol.models import EvaluationRow, Message
# +from eval_protocol.pytest.rollout_processor import RolloutProcessor
# +from eval_protocol.pytest.types import RolloutProcessorConfig
# +
# +try:
# +    import tinker
# +    from tinker_cookbook import renderers, tokenizer_utils
# +
# +    TINKER_AVAILABLE = True
# ---- eval_protocol/integrations/tinker_rollout_processor.py ----
import asyncio
import logging
import os
import time
import traceback
from typing import Any, Dict, List, Optional, Union

from eval_protocol.dataset_logger import default_logger
from eval_protocol.models import EvaluationRow, Message
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig

try:
    import tinker
    from tinker_cookbook import renderers, tokenizer_utils

    TINKER_AVAILABLE = True
except ImportError:
    TINKER_AVAILABLE = False

logger = logging.getLogger(__name__)


class TinkerRolloutProcessor(RolloutProcessor):
    """
    Rollout processor that uses a Tinker SamplingClient to generate responses.
    """

    def __init__(
        self,
        sampling_client: Optional[Any] = None,
        model_name: Optional[str] = None,
        renderer_name: str = "llama3",
    ) -> None:
        """
        Args:
            sampling_client: Pre-initialized tinker.SamplingClient. If None, one will be created using model_name.
            model_name: Name of the model. Used to create the sampling client when none is
                given, and always used to load the tokenizer in ``setup``.
            renderer_name: Name of the renderer to use for formatting messages.

        Raises:
            ImportError: If tinker / tinker-cookbook could not be imported.
        """
        if not TINKER_AVAILABLE:
            raise ImportError("tinker-cookbook is required to use TinkerRolloutProcessor")

        self.sampling_client = sampling_client
        self.model_name = model_name
        self.renderer_name = renderer_name
        self.renderer = None
        self.tokenizer = None

    def setup(self) -> None:
        """Create the sampling client (if needed), the tokenizer and the renderer.

        Raises:
            ValueError: If neither ``sampling_client`` nor ``model_name`` was provided,
                or if ``model_name`` is missing (it is needed to pick the tokenizer).
        """
        # Validate before touching the service so a bad configuration never
        # creates a client that is thrown away immediately afterwards.
        if self.sampling_client is None and self.model_name is None:
            raise ValueError("Either sampling_client or model_name must be provided")
        if not self.model_name:
            raise ValueError("model_name is required to initialize tokenizer/renderer")

        if self.sampling_client is None:
            # Initialize Tinker service client
            # This assumes TINKER_API_KEY is set in env
            service_client = tinker.ServiceClient()
            self.sampling_client = service_client.create_sampling_client(base_model=self.model_name)

        # The renderer is built on the tokenizer that matches the model.
        self.tokenizer = tokenizer_utils.get_tokenizer(self.model_name)
        self.renderer = renderers.get_renderer(self.renderer_name, tokenizer=self.tokenizer)

    def __call__(self, rows: List[EvaluationRow], config: RolloutProcessorConfig) -> List[asyncio.Task[EvaluationRow]]:
        """Generate rollout tasks using Tinker.

        One task is created per row; each task waits on ``config.semaphore`` before
        sampling. A completed task returns the same row object with an assistant
        message appended (empty content if sampling failed).
        """
        # Make the processor usable even when the caller did not run setup() first;
        # otherwise every task would fail on the missing renderer.
        if self.renderer is None or self.sampling_client is None:
            self.setup()

        async def process_row(row: EvaluationRow) -> EvaluationRow:
            start_time = time.perf_counter()

            if not row.messages:
                raise ValueError("Messages is empty")

            # Hand the renderer plain role/content dicts and keep only the
            # standard chat roles.
            convo = [
                {"role": m.role, "content": m.content}
                for m in row.messages
                if m.role in ["system", "user", "assistant"]
            ]

            prompt = self.renderer.build_generation_prompt(convo)

            # Map config.completion_params to Tinker SamplingParams,
            # falling back to defaults for anything not supplied.
            params = config.completion_params
            max_tokens = params.get("max_tokens", 512)
            temperature = params.get("temperature", 1.0)
            top_p = params.get("top_p", 1.0)
            top_k = params.get("top_k", -1)

            # Get stop sequences from renderer; normalise a missing value to an empty list.
            stop_sequences = self.renderer.get_stop_sequences()
            if stop_sequences is None:
                stop_sequences = []

            sampling_params = tinker.SamplingParams(
                max_tokens=int(max_tokens),
                temperature=float(temperature),
                top_p=float(top_p),
                top_k=int(top_k),
                stop=stop_sequences,
            )

            # Call Tinker API. Failures are logged and produce an empty assistant
            # message (best effort) rather than failing the whole batch.
            try:
                sample_result = await self.sampling_client.sample_async(
                    prompt=prompt, num_samples=1, sampling_params=sampling_params
                )

                # Parse response; the renderer result is unpacked as (message, success flag).
                sampled_tokens = sample_result.sequences[0].tokens
                message, parse_success = self.renderer.parse_response(sampled_tokens)
                if not parse_success:
                    logger.warning("Renderer reported an unsuccessful parse of the sampled response")

                assistant_content = message["content"] if message else ""

            except Exception as e:
                # Try to extract more info if '0' is not helpful
                error_details = str(e)
                if error_details == "0" and hasattr(e, "code"):
                    error_details = f"Code: {e.code}, Message: {getattr(e, 'message', 'unknown')}"
                # Log full traceback for debugging
                tb_str = traceback.format_exc()
                logger.error("Tinker sampling failed: %s\nTraceback:\n%s", error_details, tb_str)
                assistant_content = ""

            # Update row
            row.messages = list(row.messages) + [Message(role="assistant", content=assistant_content)]
            row.execution_metadata.duration_seconds = time.perf_counter() - start_time

            # Usage is not populated by this processor (placeholder).
            row.execution_metadata.usage = None

            default_logger.log(row)
            return row

        semaphore = config.semaphore

        async def _sem_wrapper(r: EvaluationRow) -> EvaluationRow:
            async with semaphore:
                return await process_row(r)

        return [asyncio.create_task(_sem_wrapper(row)) for row in rows]


# ---- next file in this diff: examples/tinker_math_rl/README.md (first part; preserved) ----
# diff --git a/examples/tinker_math_rl/README.md b/examples/tinker_math_rl/README.md
# new file mode 100644
# index 00000000..992e14b7
# --- /dev/null
# +++ b/examples/tinker_math_rl/README.md
# @@ -0,0 +1,43 @@
#
# # Tinker GSM8K Training Example
#
# This example demonstrates how to use `eval_protocol` to fetch GSM8K data and train a model using `tinker`'s RL training loop.
#
# ## Prerequisites
#
# 1. **Tinker Cookbook**: Ensure `tinker-cookbook` is available. The script attempts to add `../../../tinker-cookbook` to `sys.path`.
# 2. **Eval Protocol**: Ensure `eval-protocol` is installed with HuggingFace support.
#    ```bash
#    pip install 'eval-protocol[huggingface]'
#    ```
# 3. **Tinker API Key**: You need a Tinker API key.
#    ```bash
#    export TINKER_API_KEY=your_api_key_here
#    ```
#
# ## Running the Training
#
# Run the training script with python. We recommend using a small model for testing, such as `Qwen/Qwen3-4B-Instruct-2507`.
#
# ```bash
# # Install dependencies
# pip install 'eval-protocol[huggingface]' chz tinker
#
# # Run training
# export TINKER_API_KEY=your_api_key_here
# python train.py model_name="Qwen/Qwen3-4B-Instruct-2507" groups_per_batch=4 train_limit=100 test_limit=10
# ```
#
# ### Configuration Options
#
# - `model_name`: The model to train (e.g., `Qwen/Qwen3-4B-Instruct-2507`).
# - `groups_per_batch`: Batch size (default: 100).
# - `group_size`: Number of samples per problem (default: 4).
# - `train_limit`: Number of training examples to fetch (default: 1000).
# ---- examples/tinker_math_rl/README.md (second part; preserved) ----
# - `test_limit`: Number of test examples to fetch (default: 100).
# - `log_path`: Path to save logs and checkpoints.
#
# ## How it Works
#
# 1. **Data Loading**: The script uses `eval_protocol.adapters.huggingface.create_gsm8k_adapter` to fetch GSM8K data.
# 2. **Dataset Adaptation**: `EvalProtocolRLDataset` (through `gsm8k_row_converter`) converts `EvaluationRow` objects from `eval_protocol` into `ProblemGroupBuilder` objects expected by `tinker`.
# 3. **Training**: The standard `tinker` training loop is used to optimize the model.
#
# ---- examples/tinker_math_rl/plot_metrics.py ----
# diff --git a/examples/tinker_math_rl/plot_metrics.py b/examples/tinker_math_rl/plot_metrics.py
# new file mode 100644
# index 00000000..fe47b903
# --- /dev/null
# +++ b/examples/tinker_math_rl/plot_metrics.py
# @@ -0,0 +1,41 @@
import argparse
import json
import os

# Default locations; both can be overridden on the command line.
metrics_file = "/tmp/eval_protocol_integration_v9/metrics.jsonl"
output_file = "python-sdk/examples/tinker_math_rl/reward_plot_integration_v9.png"

# Based on the log output:
#   Eval lines look like:  {"step": 0, "test/env/all/reward/total": ...}
#   Train lines look like: {"step": 0, "env/all/reward/total": ...}
TRAIN_REWARD_KEY = "env/all/reward/total"
EVAL_REWARD_KEY = "test/env/all/reward/total"


def load_training_metrics(path):
    """Read a metrics JSONL file and return the training-step series.

    Only records that contain the training reward key, do not contain the eval
    reward key, and have a ``"step"`` are kept. Blank lines are ignored.

    Args:
        path: Path to the JSONL metrics file.

    Returns:
        ``(steps, accuracies, rewards, formats)`` as four parallel lists. Missing
        ``env/all/correct`` / ``env/all/format`` values default to 0.0.
    """
    steps = []
    accuracies = []
    rewards = []
    formats = []

    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            data = json.loads(line)

            is_training_record = TRAIN_REWARD_KEY in data and EVAL_REWARD_KEY not in data
            if not is_training_record or "step" not in data:
                continue

            steps.append(data["step"])
            rewards.append(data[TRAIN_REWARD_KEY])
            accuracies.append(data.get("env/all/correct", 0.0))
            formats.append(data.get("env/all/format", 0.0))

    return steps, accuracies, rewards, formats


def plot_training_metrics(steps, accuracies, rewards, formats, output_path):
    """Plot the three series against ``steps`` and save the figure to ``output_path``."""
    # Imported lazily so metric loading works without matplotlib installed.
    import matplotlib.pyplot as plt

    plt.figure(figsize=(10, 6))
    plt.plot(steps, accuracies, label="Accuracy", marker="o")
    plt.plot(steps, rewards, label="Total Reward", marker="o")
    plt.plot(steps, formats, label="Format Compliance", marker="o")

    plt.xlabel("Step")
    plt.ylabel("Value")
    plt.title("Training Metrics: Eval Protocol Integration V9")
    plt.legend()
    plt.grid(True)
    plt.savefig(output_path)
    plt.close()


def main(argv=None):
    """Command-line entry point; without arguments it uses the default paths above."""
    parser = argparse.ArgumentParser(description="Plot training metrics from a metrics.jsonl file.")
    parser.add_argument("--metrics-file", default=metrics_file, help="Path to the metrics JSONL file.")
    parser.add_argument("--output-file", default=output_file, help="Where to write the PNG plot.")
    args = parser.parse_args(argv)

    steps, accuracies, rewards, formats = load_training_metrics(args.metrics_file)
    plot_training_metrics(steps, accuracies, rewards, formats, args.output_file)
    print(f"Plot saved to {args.output_file}")


if __name__ == "__main__":
    main()


# ---- binary files added by this diff (preserved) ----
# diff --git a/examples/tinker_math_rl/reward_plot_integration_v5.png b/examples/tinker_math_rl/reward_plot_integration_v5.png
# new file mode 100644
# index 00000000..8ffdcff8
# Binary files /dev/null and b/examples/tinker_math_rl/reward_plot_integration_v5.png differ
# diff --git a/examples/tinker_math_rl/reward_plot_integration_v9.png b/examples/tinker_math_rl/reward_plot_integration_v9.png
# new file mode 100644
# index 00000000..4d74980f
# Binary files /dev/null and b/examples/tinker_math_rl/reward_plot_integration_v9.png differ
# diff --git a/examples/tinker_math_rl/reward_plot_refactored.png b/examples/tinker_math_rl/reward_plot_refactored.png
# new file mode 100644
# index 00000000..b80424a3
# Binary files /dev/null and b/examples/tinker_math_rl/reward_plot_refactored.png differ
#
# ---- start of the next file in this diff (shares the original source line; preserved) ----
# diff --git a/examples/tinker_math_rl/test_gsm8k_eval.py b/examples/tinker_math_rl/test_gsm8k_eval.py
# new file mode 100644
# index 00000000..0a7504dc
# --- /dev/null
# +++ b/examples/tinker_math_rl/test_gsm8k_eval.py
# @@ -0,0 +1,76 @@
# +from typing import Any, Dict, List
# +
# +from eval_protocol.models import (
# +    EvaluateResult,
# +    EvaluationRow,
# +    MetricResult,
# +)
# +from
# ---- examples/tinker_math_rl/test_gsm8k_eval.py ----
import logging
from typing import Any, Dict, List

from eval_protocol.models import (
    EvaluateResult,
    EvaluationRow,
    MetricResult,
)
from eval_protocol.pytest.evaluation_test import evaluation_test
from eval_protocol.integrations.tinker_rollout_processor import TinkerRolloutProcessor
from eval_protocol.adapters.huggingface import create_gsm8k_adapter

# Import grading logic from tinker-cookbook to ensure consistency
try:
    from tinker_cookbook.recipes.math_rl.math_grading import grade_answer
except ImportError:
    grade_answer = None

logger = logging.getLogger(__name__)


# Separate data loading for reuse in train.py
def get_gsm8k_input_rows(limit: int = 10) -> List[EvaluationRow]:
    """Return up to ``limit`` rows of the GSM8K ``"test"`` split from a fresh adapter."""
    adapter = create_gsm8k_adapter()
    return list(adapter.get_evaluation_rows(split="test", limit=limit))


# Fetch some rows for the test
gsm8k_input_rows = get_gsm8k_input_rows(limit=10)


def _response_text(content: Any) -> str:
    """Flatten message content to a string: None -> "", str as-is, otherwise join string ``.text`` parts."""
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    # Content given as a sequence of parts: keep only parts that carry string text.
    return "".join(part.text for part in content if isinstance(getattr(part, "text", None), str))


@evaluation_test(
    input_rows=gsm8k_input_rows,
    completion_params=[
        {
            "max_tokens": 512,
            "temperature": 0.0,  # Greedy for eval
        }
    ],
    rollout_processor=TinkerRolloutProcessor(model_name="meta-llama/Llama-3.1-8B-Instruct", renderer_name="llama3"),
    aggregation_method="mean",
    num_runs=1,
    max_concurrent_rollouts=4,
    mode="pointwise",
)
def test_gsm8k_tinker(row: EvaluationRow) -> EvaluationRow:
    """Score the last assistant message of ``row`` against ``row.ground_truth``.

    The score is 1.0 when ``grade_answer`` accepts the response and 0.0 otherwise
    (including when there is no assistant message or grading is unavailable).
    The result is stored on ``row.evaluation_result`` and the row is returned.
    """
    assistant_msgs = [m for m in row.messages if m.role == "assistant"]
    if not assistant_msgs:
        score = 0.0
        reason = "No assistant response"
    elif grade_answer is None:
        # Fallback when tinker-cookbook's grader could not be imported
        score = 0.0
        reason = "Grading function not available"
        logger.warning("grade_answer is unavailable; scoring row as 0.0")
    else:
        model_response = _response_text(assistant_msgs[-1].content)
        ground_truth = row.ground_truth

        # Use Tinker's grading logic
        is_correct = grade_answer(model_response, str(ground_truth))
        score = 1.0 if is_correct else 0.0
        reason = f"Graded: {is_correct}. GT: {ground_truth}"

    row.evaluation_result = EvaluateResult(
        score=score, reason=reason, metrics={"accuracy": MetricResult(score=score, reason=reason)}
    )
    return row


# ---- examples/tinker_math_rl/train.py ----
# diff --git a/examples/tinker_math_rl/train.py b/examples/tinker_math_rl/train.py
# new file mode 100644
# index 00000000..b107e322
# --- /dev/null
# +++ b/examples/tinker_math_rl/train.py
# @@ -0,0 +1,209 @@
import asyncio
import copy
import logging
import os
import sys
from functools import partial
from typing import Literal, Any, Optional

import chz
from datetime import datetime

# Add tinker-cookbook to path if not installed
# Assuming the directory structure:
# rft/
#   python-sdk/
#     examples/
#       tinker_math_rl/
#   tinker-cookbook/
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../"))
tinker_path = os.path.join(repo_root, "tinker-cookbook")
if tinker_path not in sys.path:
    sys.path.append(tinker_path)

from tinker_cookbook import cli_utils, model_info, renderers
from tinker_cookbook.recipes.math_rl.math_env import MathEnv, extract_gsm8k_final_answer
from tinker_cookbook.rl.problem_env import ProblemGroupBuilder
from tinker_cookbook.rl.train import AsyncConfig, Config, main

# eval_protocol imports
from eval_protocol.adapters.huggingface import create_gsm8k_adapter
from eval_protocol.integrations.tinker_cookbook import create_eval_protocol_dataset_builder, EvalProtocolEvaluator
from eval_protocol.integrations.tinker_rollout_processor import TinkerRolloutProcessor

# Import test components
from examples.tinker_math_rl.test_gsm8k_eval import test_gsm8k_tinker, get_gsm8k_input_rows

logger = logging.getLogger(__name__)


def gsm8k_row_converter(
    row: Any, group_size: int, renderer: renderers.Renderer, convo_prefix: list[renderers.Message] | None
) -> Optional[ProblemGroupBuilder]:
    """
    Convert an Eval Protocol EvaluationRow to a Tinker ProblemGroupBuilder for GSM8K.

    The problem is the content of the last user message; the answer comes from
    ``row.ground_truth``. If the ground truth contains ``####`` it is passed through
    ``extract_gsm8k_final_answer``; otherwise it is used as the answer directly.

    Returns:
        A builder creating ``group_size`` ``MathEnv`` instances, or ``None`` when the
        row has no user message, no problem text, no ground truth, or cannot be parsed.
    """
    try:
        # We assume the last user message is the question.
        user_msg = next((msg for msg in reversed(row.messages) if msg.role == "user"), None)
        if not user_msg:
            return None

        problem = user_msg.content
        raw_answer = row.ground_truth

        if not problem or raw_answer is None:
            return None

        # Coerce so that a non-string ground truth (e.g. a number) is still usable
        # instead of failing the `in` check below.
        answer_text = str(raw_answer)
        if not answer_text:
            return None

        # Extract final answer if it looks like a GSM8K solution (contains ####)
        # Otherwise assume it is already the answer
        if "####" in answer_text:
            answer = extract_gsm8k_final_answer(answer_text)
        else:
            answer = answer_text

    except Exception as e:
        # Best effort: a malformed row is skipped rather than aborting dataset construction.
        logger.warning(f"Failed to parse row: {e}")
        return None

    return ProblemGroupBuilder(
        env_thunk=partial(MathEnv, problem, answer, renderer, convo_prefix=convo_prefix),
        num_envs=group_size,
    )


@chz.chz
class CLIConfig:
    """Simple command-line configuration for RL training with Eval Protocol."""

    # Model configuration
    model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
    lora_rank: int = 32
    # When None, the recommended renderer for model_name is looked up in cli_main.
    renderer_name: str | None = None
    load_checkpoint_path: str | None = None

    # Training hyperparameters
    group_size: int = 4
    groups_per_batch: int = 100
    learning_rate: float = 1e-5
    max_tokens: int = 512  # Increased for reasoning
    temperature: float = 1.0
    kl_penalty_coef: float = 0.0

    num_substeps: int = 1

    # Logging configuration
    # When None, cli_main derives a path under /tmp/tinker-examples/ep_math_rl/.
    log_path: str | None = None
    wandb_project: str | None = None
    # When None, the generated run name is used.
    wandb_name: str | None = None
    compute_post_kl: bool = False

    # Evals
    eval_every: int = 20

    # Checkpointing
    save_every: int = 20

    # Service configuration
    base_url: str | None = None

    behavior_if_log_dir_exists: cli_utils.LogdirBehavior = "ask"

    # When set, an AsyncConfig is passed to the training Config.
    max_steps_off_policy: int | None = None
    loss_fn: Literal["importance_sampling", "ppo"] = "importance_sampling"

    # Dataset limits
    train_limit: int = 1000
    test_limit: int = 100


async def cli_main(cli_config: CLIConfig):
    """Convert CLI config to full config and run training."""

    # Resolve the renderer: explicit CLI value, else the recommended one for the model.
    renderer_name = cli_config.renderer_name or model_info.get_recommended_renderer_name(cli_config.model_name)

    model_name_slug = cli_config.model_name.replace("/", "-")
    run_name = f"ep-gsm8k-{model_name_slug}-{cli_config.lora_rank}rank-{datetime.now().strftime('%Y-%m-%d-%H-%M')}"

    log_path = (
        cli_config.log_path if cli_config.log_path is not None else f"/tmp/tinker-examples/ep_math_rl/{run_name}"
    )
    wandb_name = cli_config.wandb_name if cli_config.wandb_name is not None else run_name

    # Create the builder class dynamically using the factory
    # We use create_gsm8k_adapter as the adapter factory
    # We use MathEnv.standard_fewshot_prefix as the prefix factory
    EvalProtocolDatasetBuilder = create_eval_protocol_dataset_builder(
        adapter_factory=create_gsm8k_adapter,
        row_converter=gsm8k_row_converter,
        convo_prefix_factory=MathEnv.standard_fewshot_prefix,
        train_limit=cli_config.train_limit,
        test_limit=cli_config.test_limit,
    )

    # Create the EvalProtocol Evaluator
    # Use the test_limit for the number of rows to evaluate
    eval_rows = get_gsm8k_input_rows(limit=cli_config.test_limit)

    # Need to wrap in a factory as expected by Config.evaluator_builders
    def create_eval_protocol_evaluator():
        return EvalProtocolEvaluator(
            # Each evaluator gets its own copy so the fetched rows stay pristine.
            rows=copy.deepcopy(eval_rows),
            eval_func=test_gsm8k_tinker,
            rollout_processor_cls=TinkerRolloutProcessor,
            model_name=cli_config.model_name,
            renderer_name=renderer_name,
            max_tokens=cli_config.max_tokens,
            temperature=0.0,  # Greedy for eval
        )

    # Asynchronous training is enabled only when max_steps_off_policy is given.
    async_config = (
        AsyncConfig(
            max_steps_off_policy=cli_config.max_steps_off_policy,
            groups_per_batch=cli_config.groups_per_batch,
        )
        if cli_config.max_steps_off_policy is not None
        else None
    )

    # Create full config
    config = Config(
        learning_rate=cli_config.learning_rate,
        dataset_builder=EvalProtocolDatasetBuilder(
            batch_size=cli_config.groups_per_batch,
            model_name_for_tokenizer=cli_config.model_name,
            renderer_name=renderer_name,
            group_size=cli_config.group_size,
        ),
        model_name=cli_config.model_name,
        lora_rank=cli_config.lora_rank,
        max_tokens=cli_config.max_tokens,
        temperature=cli_config.temperature,
        wandb_project=cli_config.wandb_project,
        wandb_name=wandb_name,
        log_path=log_path,
        base_url=cli_config.base_url,
        load_checkpoint_path=cli_config.load_checkpoint_path,
        compute_post_kl=cli_config.compute_post_kl,
        kl_penalty_coef=cli_config.kl_penalty_coef,
        num_substeps=cli_config.num_substeps,
        eval_every=cli_config.eval_every,
        save_every=cli_config.save_every,
        async_config=async_config,
        loss_fn=cli_config.loss_fn,
        # Add our custom evaluator
        evaluator_builders=[create_eval_protocol_evaluator],
    )

    cli_utils.check_log_dir(log_path, behavior_if_exists=cli_config.behavior_if_log_dir_exists)

    # Run training
    await main(config)


if __name__ == "__main__":
    cli_config = chz.entrypoint(CLIConfig)
    asyncio.run(cli_main(cli_config))