|
34 | 34 |
|
35 | 35 | def build_parser() -> argparse.ArgumentParser: |
36 | 36 | """Build and return the argument parser for the CLI.""" |
37 | | - parser = argparse.ArgumentParser(description="eval-protocol: Tools for evaluation and reward modeling") |
| 37 | + parser = argparse.ArgumentParser( |
| 38 | + description="Inspect evaluation runs locally, upload evaluators, and create reinforcement fine-tuning jobs on Fireworks" |
| 39 | + ) |
38 | 40 | return _configure_parser(parser) |
39 | 41 |
|
40 | 42 |
|
@@ -401,39 +403,52 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse |
401 | 403 | rft_parser.add_argument("--base-model", help="Base model resource id") |
402 | 404 | rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from") |
403 | 405 | rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)") |
404 | | - rft_parser.add_argument("--epochs", type=int, default=1) |
405 | | - rft_parser.add_argument("--batch-size", type=int, default=128000) |
406 | | - rft_parser.add_argument("--learning-rate", type=float, default=3e-5) |
407 | | - rft_parser.add_argument("--max-context-length", type=int, default=65536) |
408 | | - rft_parser.add_argument("--lora-rank", type=int, default=16) |
| 406 | + rft_parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs") |
| 407 | + rft_parser.add_argument("--batch-size", type=int, default=128000, help="Training batch size in tokens") |
| 408 | + rft_parser.add_argument("--learning-rate", type=float, default=3e-5, help="Learning rate for training") |
| 409 | + rft_parser.add_argument("--max-context-length", type=int, default=65536, help="Maximum context length in tokens") |
| 410 | + rft_parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank for fine-tuning") |
409 | 411 | rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps") |
410 | | - rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of LR warmup steps") |
411 | | - rft_parser.add_argument("--accelerator-count", type=int) |
412 | | - rft_parser.add_argument("--region", help="Fireworks region enum value") |
413 | | - rft_parser.add_argument("--display-name", help="RFT job display name") |
414 | | - rft_parser.add_argument("--evaluation-dataset", help="Optional separate eval dataset id") |
415 | | - rft_parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout", action="store_true", default=True) |
416 | | - rft_parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout", action="store_false") |
| 412 | + rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of learning rate warmup steps") |
| 413 | + rft_parser.add_argument("--accelerator-count", type=int, help="Number of accelerators (GPUs) to use") |
| 414 | + rft_parser.add_argument("--region", help="Fireworks region for training") |
| 415 | + rft_parser.add_argument("--display-name", help="Display name for the RFT job") |
| 416 | + rft_parser.add_argument("--evaluation-dataset", help="Separate dataset id for evaluation") |
| 417 | + rft_parser.add_argument( |
| 418 | + "--eval-auto-carveout", |
| 419 | + dest="eval_auto_carveout", |
| 420 | + action="store_true", |
| 421 | + default=True, |
| 422 | + help="Automatically carve out evaluation data from the training set", |
| 423 | + ) |
| 424 | + rft_parser.add_argument( |
| 425 | + "--no-eval-auto-carveout", |
| 426 | + dest="eval_auto_carveout", |
| 427 | + action="store_false", |
| 428 | + help="Disable automatic evaluation data carveout", |
| 429 | + ) |
417 | 430 | # Rollout chunking |
418 | 431 | rft_parser.add_argument("--chunk-size", type=int, default=100, help="Data chunk size for rollout batching") |
419 | 432 | # Inference params |
420 | | - rft_parser.add_argument("--temperature", type=float) |
421 | | - rft_parser.add_argument("--top-p", type=float) |
422 | | - rft_parser.add_argument("--top-k", type=int) |
423 | | - rft_parser.add_argument("--max-output-tokens", type=int, default=32768) |
424 | | - rft_parser.add_argument("--response-candidates-count", type=int, default=8) |
| 433 | + rft_parser.add_argument("--temperature", type=float, help="Sampling temperature for rollouts") |
| 434 | + rft_parser.add_argument("--top-p", type=float, help="Top-p (nucleus) sampling parameter") |
| 435 | + rft_parser.add_argument("--top-k", type=int, help="Top-k sampling parameter") |
| 436 | + rft_parser.add_argument("--max-output-tokens", type=int, default=32768, help="Maximum output tokens per rollout") |
| 437 | + rft_parser.add_argument( |
| 438 | + "--response-candidates-count", type=int, default=8, help="Number of response candidates per prompt" |
| 439 | + ) |
425 | 440 | rft_parser.add_argument("--extra-body", help="JSON string for extra inference params") |
426 | 441 | # MCP server (optional) |
427 | 442 | rft_parser.add_argument( |
428 | 443 | "--mcp-server", |
429 | | - help="The MCP server resource name to use for the reinforcement fine-tuning job.", |
| 444 | + help="MCP server resource name for agentic rollouts", |
430 | 445 | ) |
431 | 446 | # Wandb |
432 | | - rft_parser.add_argument("--wandb-enabled", action="store_true") |
433 | | - rft_parser.add_argument("--wandb-project") |
434 | | - rft_parser.add_argument("--wandb-entity") |
435 | | - rft_parser.add_argument("--wandb-run-id") |
436 | | - rft_parser.add_argument("--wandb-api-key") |
| 447 | + rft_parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights & Biases logging") |
| 448 | + rft_parser.add_argument("--wandb-project", help="Weights & Biases project name") |
| 449 | + rft_parser.add_argument("--wandb-entity", help="Weights & Biases entity (username or team)") |
| 450 | + rft_parser.add_argument("--wandb-run-id", help="Weights & Biases run id for resuming") |
| 451 | + rft_parser.add_argument("--wandb-api-key", help="Weights & Biases API key") |
437 | 452 | # Misc |
438 | 453 | rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id") |
439 | 454 | rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode") |
|
0 commit comments