Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 63 additions & 81 deletions eval_protocol/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
"""

import argparse
import inspect
import json
import logging
import os
import sys
from pathlib import Path
from typing import Any, cast
from .cli_commands.utils import add_args_from_callable_signature

from fireworks import Fireworks

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -374,87 +379,11 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
"rft",
help="Create a Reinforcement Fine-tuning Job on Fireworks",
)
rft_parser.add_argument(
"--evaluator",
help="Evaluator ID or fully-qualified resource (accounts/{acct}/evaluators/{id}); if omitted, derive from local tests",
)
# Dataset options
rft_parser.add_argument(
"--dataset",
help="Use existing dataset (ID or resource 'accounts/{acct}/datasets/{id}') to skip local materialization",
)
rft_parser.add_argument(
"--dataset-jsonl",
help="Path to JSONL to upload as a new Fireworks dataset",
)
rft_parser.add_argument(
"--dataset-builder",
help="Explicit dataset builder spec (module::function or path::function)",
)
rft_parser.add_argument(
"--dataset-display-name",
help="Display name for dataset on Fireworks (defaults to dataset id)",
)
# Training config and evaluator/job settings
rft_parser.add_argument("--base-model", help="Base model resource id")
rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from")
rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)")
rft_parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
rft_parser.add_argument("--batch-size", type=int, default=128000, help="Training batch size in tokens")
rft_parser.add_argument("--learning-rate", type=float, default=3e-5, help="Learning rate for training")
rft_parser.add_argument("--max-context-length", type=int, default=65536, help="Maximum context length in tokens")
rft_parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank for fine-tuning")
rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps")
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of learning rate warmup steps")
rft_parser.add_argument("--accelerator-count", type=int, help="Number of accelerators (GPUs) to use")
rft_parser.add_argument("--region", help="Fireworks region for training")
rft_parser.add_argument("--display-name", help="Display name for the RFT job")
rft_parser.add_argument("--evaluation-dataset", help="Separate dataset id for evaluation")
rft_parser.add_argument(
"--eval-auto-carveout",
dest="eval_auto_carveout",
action="store_true",
default=True,
help="Automatically carve out evaluation data from training set",
)
rft_parser.add_argument(
"--no-eval-auto-carveout",
dest="eval_auto_carveout",
action="store_false",
help="Disable automatic evaluation data carveout",
)
# Rollout chunking
rft_parser.add_argument("--chunk-size", type=int, default=100, help="Data chunk size for rollout batching")
# Inference params
rft_parser.add_argument("--temperature", type=float, help="Sampling temperature for rollouts")
rft_parser.add_argument("--top-p", type=float, help="Top-p (nucleus) sampling parameter")
rft_parser.add_argument("--top-k", type=int, help="Top-k sampling parameter")
rft_parser.add_argument("--max-output-tokens", type=int, default=32768, help="Maximum output tokens per rollout")
rft_parser.add_argument(
"--response-candidates-count", type=int, default=8, help="Number of response candidates per prompt"
)
rft_parser.add_argument("--extra-body", help="JSON string for extra inference params")
# MCP server (optional)
rft_parser.add_argument(
"--mcp-server",
help="MCP server resource name for agentic rollouts",
)
# Wandb
rft_parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights & Biases logging")
rft_parser.add_argument("--wandb-project", help="Weights & Biases project name")
rft_parser.add_argument("--wandb-entity", help="Weights & Biases entity (username or team)")
rft_parser.add_argument("--wandb-run-id", help="Weights & Biases run id for resuming")
rft_parser.add_argument("--wandb-api-key", help="Weights & Biases API key")
# Misc
rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id")

rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending")
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned SDK call without sending")
rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID")
rft_parser.add_argument(
"--skip-validation",
action="store_true",
help="Skip local dataset and evaluator validation before creating the RFT job",
)
rft_parser.add_argument("--skip-validation", action="store_true", help="Skip local dataset/evaluator validation")
rft_parser.add_argument(
"--ignore-docker",
action="store_true",
Expand All @@ -463,14 +392,64 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
rft_parser.add_argument(
"--docker-build-extra",
default="",
metavar="",
Comment thread
xzrderek marked this conversation as resolved.
help="Extra flags to pass to 'docker build' when validating evaluator (quoted string, e.g. \"--no-cache --pull --progress=plain\")",
)
rft_parser.add_argument(
"--docker-run-extra",
default="",
metavar="",
help="Extra flags to pass to 'docker run' when validating evaluator (quoted string, e.g. \"--env-file .env --memory=8g\")",
)

# Everything below has to manually be maintained, can't be auto-generated
Comment thread
xzrderek marked this conversation as resolved.
Outdated
rft_parser.add_argument(
"--source-job",
metavar="",
help="The source reinforcement fine-tuning job to copy configuration from. If other flags are set, they will override the source job's configuration.",
)
rft_parser.add_argument(
"--quiet",
action="store_true",
help="If set, only errors will be printed.",
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing CLI arguments for dataset creation workflow

The --dataset-jsonl, --dataset-builder, and --dataset-display-name arguments were removed from the CLI but the code in create_rft.py still expects them via getattr(args, "dataset_jsonl", None) and getattr(args, "dataset_display_name", None). These are workflow-specific arguments (not SDK parameters) for creating datasets from local JSONL files. Users attempting to use --dataset-jsonl will get "unrecognized arguments" errors, and the dataset creation workflow from local files is broken. The comment at line 405-406 acknowledges that workflow controls must be maintained manually, but these arguments were not added.

Additional Locations (1)

Fix in Cursor Fix in Web

skip_fields = {
"__top_level__": {
"extra_headers",
"extra_query",
"extra_body",
"timeout",
"node_count",
"display_name",
"account_id",
},
"loss_config": {"kl_beta", "method"},
Comment thread
xzrderek marked this conversation as resolved.
Outdated
"training_config": {"region", "jinja_template"},
Comment thread
xzrderek marked this conversation as resolved.
"wandb_config": {"run_id"},
Comment thread
xzrderek marked this conversation as resolved.
}
aliases = {
"wandb_config.api_key": ["--wandb-api-key"],
"wandb_config.project": ["--wandb-project"],
"wandb_config.entity": ["--wandb-entity"],
"wandb_config.enabled": ["--wandb"],
"reinforcement_fine_tuning_job_id": ["--job-id"],
}
Comment thread
xzrderek marked this conversation as resolved.
help_overrides = {
"training_config.gradient_accumulation_steps": "The number of batches to accumulate gradients before updating the model parameters. The effective batch size will be batch-size multiplied by this value.",
"training_config.learning_rate_warmup_steps": "The number of learning rate warmup steps for the reinforcement fine-tuning job.",
"mcp_server": "The MCP server resource name to use for the reinforcement fine-tuning job. (Optional)",
}

create_rft_job_fn = Fireworks().reinforcement_fine_tuning_jobs.create
Comment thread
xzrderek marked this conversation as resolved.

add_args_from_callable_signature(
rft_parser,
create_rft_job_fn,
skip_fields=skip_fields,
aliases=aliases,
help_overrides=help_overrides,
)

# Local test command
local_test_parser = subparsers.add_parser(
"local-test",
Expand Down Expand Up @@ -542,8 +521,11 @@ def _hide_suppressed_subparsers(parser: argparse.ArgumentParser) -> None:
def parse_args(args=None):
"""Parse command line arguments."""
parser = build_parser()
# Use parse_known_args to allow Hydra to handle its own arguments
return parser.parse_known_args(args)
# Fail fast on unknown flags so typos don't silently get ignored.
parsed, remaining = parser.parse_known_args(args)
if remaining:
parser.error(f"unrecognized arguments: {' '.join(remaining)}")
return parsed, remaining
Comment thread
xzrderek marked this conversation as resolved.
Comment thread
cursor[bot] marked this conversation as resolved.


def main():
Expand Down
14 changes: 8 additions & 6 deletions eval_protocol/cli_commands/create_rft.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,13 +681,13 @@ def _create_rft_job(
return 1

wandb_config: Optional[Dict[str, Any]] = None
if getattr(args, "wandb_enabled", False):
if getattr(args, "enabled", False):
Comment thread
xzrderek marked this conversation as resolved.
Outdated
Comment thread
xzrderek marked this conversation as resolved.
Outdated
Comment thread
xzrderek marked this conversation as resolved.
Outdated
wandb_config = {
"enabled": True,
"apiKey": getattr(args, "wandb_api_key", None),
"project": getattr(args, "wandb_project", None),
"entity": getattr(args, "wandb_entity", None),
"runId": getattr(args, "wandb_run_id", None),
"apiKey": getattr(args, "api_key", None),
"project": getattr(args, "project", None),
"entity": getattr(args, "entity", None),
"runId": getattr(args, "run_id", None),
Comment thread
xzrderek marked this conversation as resolved.
Outdated
Comment thread
xzrderek marked this conversation as resolved.
Outdated
}

body: Dict[str, Any] = {
Expand All @@ -702,7 +702,9 @@ def _create_rft_job(
"outputStats": None,
"outputMetrics": None,
"mcpServer": getattr(args, "mcp_server", None),
"jobId": getattr(args, "job_id", None),
"jobId": getattr(args, "reinforcement_fine_tuning_job_id", None),
"sourceJob": getattr(args, "source_job", None),
Comment thread
xzrderek marked this conversation as resolved.
Outdated
"quiet": getattr(args, "quiet", False),
}
# Debug: print minimal summary
print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
Expand Down
Loading
Loading