From c637c34b9a0201f84f5b48f69bd3767d2394d79e Mon Sep 17 00:00:00 2001 From: Benny Chen Date: Thu, 30 Oct 2025 00:52:49 -0700 Subject: [PATCH] add chunk size --- .gitignore | 1 + eval_protocol/cli.py | 2 ++ eval_protocol/cli_commands/create_rft.py | 1 + tests/pytest/gsm8k/test_pytest_math_example.py | 5 +++++ 4 files changed, 9 insertions(+) diff --git a/.gitignore b/.gitignore index d4f0d2df..7e434174 100644 --- a/.gitignore +++ b/.gitignore @@ -242,3 +242,4 @@ package-lock.json package.json tau2-bench *.err +eval-protocol diff --git a/eval_protocol/cli.py b/eval_protocol/cli.py index 125198e1..fc9c4d3f 100644 --- a/eval_protocol/cli.py +++ b/eval_protocol/cli.py @@ -402,6 +402,8 @@ def parse_args(args=None): rft_parser.add_argument("--evaluation-dataset", help="Optional separate eval dataset id") rft_parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout", action="store_true", default=True) rft_parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout", action="store_false") + # Rollout chunking + rft_parser.add_argument("--chunk-size", type=int, help="Data chunk size for rollout batching") # Inference params rft_parser.add_argument("--temperature", type=float) rft_parser.add_argument("--top-p", type=float) diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py index 9580340a..7a2fe8c4 100644 --- a/eval_protocol/cli_commands/create_rft.py +++ b/eval_protocol/cli_commands/create_rft.py @@ -379,6 +379,7 @@ def create_rft_command(args) -> int: "trainingConfig": training_config, "inferenceParameters": inference_params or None, "wandbConfig": wandb_config, + "chunkSize": getattr(args, "chunk_size", None), "outputStats": None, "outputMetrics": None, "mcpServer": None, diff --git a/tests/pytest/gsm8k/test_pytest_math_example.py b/tests/pytest/gsm8k/test_pytest_math_example.py index 961ff479..ec940f5c 100644 --- a/tests/pytest/gsm8k/test_pytest_math_example.py +++ b/tests/pytest/gsm8k/test_pytest_math_example.py @@ -4,6 +4,9 @@ import os from eval_protocol.data_loader.jsonl_data_loader import EvaluationRowJsonlDataLoader from typing import List, Dict, Any, Optional +import logging + +logger = logging.getLogger(__name__) def extract_answer_digits(ground_truth: str) -> Optional[str]: @@ -54,6 +57,7 @@ def test_math_dataset(row: EvaluationRow, **kwargs) -> EvaluationRow: EvaluationRow with the evaluation result """ #### Get predicted answer value + logger.info(f"I am beginning to execute GSM8k rollout: {row.execution_metadata.rollout_id}") prediction = extract_answer_digits(str(row.messages[2].content)) gt = extract_answer_digits(str(row.ground_truth)) @@ -77,5 +81,6 @@ def test_math_dataset(row: EvaluationRow, **kwargs) -> EvaluationRow: is_score_valid=True, # Optional: Whether the score is valid, true by default reason=reason, # Optional: The reason for the score ) + logger.info(f"I am done executing GSM8k rollout: {row.execution_metadata.rollout_id}") row.evaluation_result = evaluation_result return row