|
32 | 32 | preview_command = None # type: ignore[assignment] |
33 | 33 |
|
34 | 34 |
|
35 | | -def parse_args(args=None): |
36 | | - """Parse command line arguments""" |
37 | | - parser = argparse.ArgumentParser(description="eval-protocol: Tools for evaluation and reward modeling") |
def build_parser() -> argparse.ArgumentParser:
    """Create the CLI's top-level argument parser.

    A fresh ``argparse.ArgumentParser`` carrying the tool description is
    handed to ``_configure_parser``, and whatever that helper returns is
    passed back to the caller.

    Returns:
        The parser returned by ``_configure_parser``.
    """
    cli_description = "Inspect evaluation runs locally, upload evaluators, and create reinforcement fine-tuning jobs on Fireworks"
    bare_parser = argparse.ArgumentParser(description=cli_description)
    return _configure_parser(bare_parser)
| 41 | + |
| 42 | + |
| 43 | +def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: |
| 44 | + """Configure all arguments and subparsers on the given parser.""" |
38 | 45 | parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging") |
39 | 46 | parser.add_argument( |
40 | 47 | "--profile", |
@@ -396,39 +403,52 @@ def parse_args(args=None): |
396 | 403 | rft_parser.add_argument("--base-model", help="Base model resource id") |
397 | 404 | rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from") |
398 | 405 | rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)") |
399 | | - rft_parser.add_argument("--epochs", type=int, default=1) |
400 | | - rft_parser.add_argument("--batch-size", type=int, default=128000) |
401 | | - rft_parser.add_argument("--learning-rate", type=float, default=3e-5) |
402 | | - rft_parser.add_argument("--max-context-length", type=int, default=65536) |
403 | | - rft_parser.add_argument("--lora-rank", type=int, default=16) |
| 406 | + rft_parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs") |
| 407 | + rft_parser.add_argument("--batch-size", type=int, default=128000, help="Training batch size in tokens") |
| 408 | + rft_parser.add_argument("--learning-rate", type=float, default=3e-5, help="Learning rate for training") |
| 409 | + rft_parser.add_argument("--max-context-length", type=int, default=65536, help="Maximum context length in tokens") |
| 410 | + rft_parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank for fine-tuning") |
404 | 411 | rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps") |
405 | | - rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of LR warmup steps") |
406 | | - rft_parser.add_argument("--accelerator-count", type=int) |
407 | | - rft_parser.add_argument("--region", help="Fireworks region enum value") |
408 | | - rft_parser.add_argument("--display-name", help="RFT job display name") |
409 | | - rft_parser.add_argument("--evaluation-dataset", help="Optional separate eval dataset id") |
410 | | - rft_parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout", action="store_true", default=True) |
411 | | - rft_parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout", action="store_false") |
| 412 | + rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of learning rate warmup steps") |
| 413 | + rft_parser.add_argument("--accelerator-count", type=int, help="Number of accelerators (GPUs) to use") |
| 414 | + rft_parser.add_argument("--region", help="Fireworks region for training") |
| 415 | + rft_parser.add_argument("--display-name", help="Display name for the RFT job") |
| 416 | + rft_parser.add_argument("--evaluation-dataset", help="Separate dataset id for evaluation") |
| 417 | + rft_parser.add_argument( |
| 418 | + "--eval-auto-carveout", |
| 419 | + dest="eval_auto_carveout", |
| 420 | + action="store_true", |
| 421 | + default=True, |
| 422 | + help="Automatically carve out evaluation data from training set", |
| 423 | + ) |
| 424 | + rft_parser.add_argument( |
| 425 | + "--no-eval-auto-carveout", |
| 426 | + dest="eval_auto_carveout", |
| 427 | + action="store_false", |
| 428 | + help="Disable automatic evaluation data carveout", |
| 429 | + ) |
412 | 430 | # Rollout chunking |
413 | 431 | rft_parser.add_argument("--chunk-size", type=int, default=100, help="Data chunk size for rollout batching") |
414 | 432 | # Inference params |
415 | | - rft_parser.add_argument("--temperature", type=float) |
416 | | - rft_parser.add_argument("--top-p", type=float) |
417 | | - rft_parser.add_argument("--top-k", type=int) |
418 | | - rft_parser.add_argument("--max-output-tokens", type=int, default=32768) |
419 | | - rft_parser.add_argument("--response-candidates-count", type=int, default=8) |
| 433 | + rft_parser.add_argument("--temperature", type=float, help="Sampling temperature for rollouts") |
| 434 | + rft_parser.add_argument("--top-p", type=float, help="Top-p (nucleus) sampling parameter") |
| 435 | + rft_parser.add_argument("--top-k", type=int, help="Top-k sampling parameter") |
| 436 | + rft_parser.add_argument("--max-output-tokens", type=int, default=32768, help="Maximum output tokens per rollout") |
| 437 | + rft_parser.add_argument( |
| 438 | + "--response-candidates-count", type=int, default=8, help="Number of response candidates per prompt" |
| 439 | + ) |
420 | 440 | rft_parser.add_argument("--extra-body", help="JSON string for extra inference params") |
421 | 441 | # MCP server (optional) |
422 | 442 | rft_parser.add_argument( |
423 | 443 | "--mcp-server", |
424 | | - help="The MCP server resource name to use for the reinforcement fine-tuning job.", |
| 444 | + help="MCP server resource name for agentic rollouts", |
425 | 445 | ) |
426 | 446 | # Wandb |
427 | | - rft_parser.add_argument("--wandb-enabled", action="store_true") |
428 | | - rft_parser.add_argument("--wandb-project") |
429 | | - rft_parser.add_argument("--wandb-entity") |
430 | | - rft_parser.add_argument("--wandb-run-id") |
431 | | - rft_parser.add_argument("--wandb-api-key") |
| 447 | + rft_parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights & Biases logging") |
| 448 | + rft_parser.add_argument("--wandb-project", help="Weights & Biases project name") |
| 449 | + rft_parser.add_argument("--wandb-entity", help="Weights & Biases entity (username or team)") |
| 450 | + rft_parser.add_argument("--wandb-run-id", help="Weights & Biases run id for resuming") |
| 451 | + rft_parser.add_argument("--wandb-api-key", help="Weights & Biases API key") |
432 | 452 | # Misc |
433 | 453 | rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id") |
434 | 454 | rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode") |
@@ -494,6 +514,38 @@ def parse_args(args=None): |
494 | 514 | # help="Run an evaluation using a Hydra configuration. All arguments after 'run' are passed to Hydra.", |
495 | 515 | # ) |
496 | 516 |
|
| 517 | + # Hidden command: export-docs (for generating CLI reference documentation) |
| 518 | + export_docs_parser = subparsers.add_parser("export-docs", help=argparse.SUPPRESS) |
| 519 | + export_docs_parser.add_argument( |
| 520 | + "--output", |
| 521 | + "-o", |
| 522 | + default="./docs/cli-reference.md", |
| 523 | + help="Output markdown file path (default: ./docs/cli-reference.md)", |
| 524 | + ) |
| 525 | + |
| 526 | + # Update metavar to only show visible commands (exclude those with SUPPRESS) |
| 527 | + _hide_suppressed_subparsers(parser) |
| 528 | + |
| 529 | + return parser |
| 530 | + |
| 531 | + |
def _hide_suppressed_subparsers(parser: argparse.ArgumentParser) -> None:
    """Hide subcommands registered with ``help=argparse.SUPPRESS`` from help output.

    For every subparsers action attached to ``parser`` this:

    * drops the suppressed entries from the per-command help listing, and
    * rewrites the action's ``metavar`` so the ``{a,b,c}`` usage summary lists
      only the commands that are still visible.

    Hidden commands stay registered in ``action.choices``, so they can still
    be invoked; they are only removed from what ``--help`` prints.

    Args:
        parser: Parser whose direct subparsers actions are updated in place.

    Note:
        This relies on private argparse attributes (``_actions``,
        ``_SubParsersAction``, ``_choices_actions``).
    """
    for action in parser._actions:
        if not isinstance(action, argparse._SubParsersAction):
            continue

        choices_actions = getattr(action, "_choices_actions", [])
        hidden_names = {entry.dest for entry in choices_actions if entry.help == argparse.SUPPRESS}

        # Per-command help listing: keep only the entries that are not suppressed.
        action._choices_actions = [entry for entry in choices_actions if entry.help != argparse.SUPPRESS]

        # argparse only records a help entry when ``add_parser`` is given a
        # ``help=`` keyword, so the visible names must come from
        # ``action.choices``; otherwise commands without help text would
        # silently vanish from the usage line. Comparing parser identity also
        # hides any aliases that point at a suppressed command.
        hidden_parser_ids = {id(action.choices[name]) for name in hidden_names if name in action.choices}
        visible_names = [name for name, sub in action.choices.items() if id(sub) not in hidden_parser_ids]

        # Leave the metavar untouched when nothing would remain visible.
        if visible_names:
            action.metavar = "{" + ",".join(visible_names) + "}"
| 544 | + |
| 545 | + |
def parse_args(args=None):
    """Parse command line arguments.

    Args:
        args: Argument list to parse, forwarded unchanged to
            ``parse_known_args``.

    Returns:
        The result of ``parse_known_args`` on the parser from ``build_parser``.
    """
    # parse_known_args (not parse_args) is used so that arguments this parser
    # does not recognize are left over for Hydra to handle itself.
    return build_parser().parse_known_args(args)
499 | 551 |
|
@@ -623,6 +675,10 @@ def _extract_flag_value(argv_list, flag_name): |
623 | 675 | from .cli_commands.local_test import local_test_command |
624 | 676 |
|
625 | 677 | return local_test_command(args) |
| 678 | + elif args.command == "export-docs": |
| 679 | + from .cli_commands.export_docs import export_docs_command |
| 680 | + |
| 681 | + return export_docs_command(args) |
626 | 682 | # elif args.command == "run": |
627 | 683 | # # For the 'run' command, Hydra takes over argument parsing. |
628 | 684 | # |
|
0 commit comments