Skip to content

Commit 29afd31

Browse files
author
Dylan Huang
authored
Generate CLI reference for evalprotocol.io (#375)
* save * update * Enhance CLI documentation generation by updating subparser help extraction. Introduce a method to hide suppressed commands from help output and ensure accurate help text is included for subparsers. * remove generated cli-reference * update
1 parent 6f6afa2 commit 29afd31

File tree

2 files changed

+354
-26
lines changed

2 files changed

+354
-26
lines changed

eval_protocol/cli.py

Lines changed: 82 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,16 @@
3232
preview_command = None # type: ignore[assignment]
3333

3434

35-
def parse_args(args=None):
36-
"""Parse command line arguments"""
37-
parser = argparse.ArgumentParser(description="eval-protocol: Tools for evaluation and reward modeling")
35+
def build_parser() -> argparse.ArgumentParser:
36+
"""Build and return the argument parser for the CLI."""
37+
parser = argparse.ArgumentParser(
38+
description="Inspect evaluation runs locally, upload evaluators, and create reinforcement fine-tuning jobs on Fireworks"
39+
)
40+
return _configure_parser(parser)
41+
42+
43+
def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
44+
"""Configure all arguments and subparsers on the given parser."""
3845
parser.add_argument("--verbose", "-v", action="store_true", help="Enable verbose logging")
3946
parser.add_argument(
4047
"--profile",
@@ -396,39 +403,52 @@ def parse_args(args=None):
396403
rft_parser.add_argument("--base-model", help="Base model resource id")
397404
rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from")
398405
rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)")
399-
rft_parser.add_argument("--epochs", type=int, default=1)
400-
rft_parser.add_argument("--batch-size", type=int, default=128000)
401-
rft_parser.add_argument("--learning-rate", type=float, default=3e-5)
402-
rft_parser.add_argument("--max-context-length", type=int, default=65536)
403-
rft_parser.add_argument("--lora-rank", type=int, default=16)
406+
rft_parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
407+
rft_parser.add_argument("--batch-size", type=int, default=128000, help="Training batch size in tokens")
408+
rft_parser.add_argument("--learning-rate", type=float, default=3e-5, help="Learning rate for training")
409+
rft_parser.add_argument("--max-context-length", type=int, default=65536, help="Maximum context length in tokens")
410+
rft_parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank for fine-tuning")
404411
rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps")
405-
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of LR warmup steps")
406-
rft_parser.add_argument("--accelerator-count", type=int)
407-
rft_parser.add_argument("--region", help="Fireworks region enum value")
408-
rft_parser.add_argument("--display-name", help="RFT job display name")
409-
rft_parser.add_argument("--evaluation-dataset", help="Optional separate eval dataset id")
410-
rft_parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout", action="store_true", default=True)
411-
rft_parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout", action="store_false")
412+
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of learning rate warmup steps")
413+
rft_parser.add_argument("--accelerator-count", type=int, help="Number of accelerators (GPUs) to use")
414+
rft_parser.add_argument("--region", help="Fireworks region for training")
415+
rft_parser.add_argument("--display-name", help="Display name for the RFT job")
416+
rft_parser.add_argument("--evaluation-dataset", help="Separate dataset id for evaluation")
417+
rft_parser.add_argument(
418+
"--eval-auto-carveout",
419+
dest="eval_auto_carveout",
420+
action="store_true",
421+
default=True,
422+
help="Automatically carve out evaluation data from training set",
423+
)
424+
rft_parser.add_argument(
425+
"--no-eval-auto-carveout",
426+
dest="eval_auto_carveout",
427+
action="store_false",
428+
help="Disable automatic evaluation data carveout",
429+
)
412430
# Rollout chunking
413431
rft_parser.add_argument("--chunk-size", type=int, default=100, help="Data chunk size for rollout batching")
414432
# Inference params
415-
rft_parser.add_argument("--temperature", type=float)
416-
rft_parser.add_argument("--top-p", type=float)
417-
rft_parser.add_argument("--top-k", type=int)
418-
rft_parser.add_argument("--max-output-tokens", type=int, default=32768)
419-
rft_parser.add_argument("--response-candidates-count", type=int, default=8)
433+
rft_parser.add_argument("--temperature", type=float, help="Sampling temperature for rollouts")
434+
rft_parser.add_argument("--top-p", type=float, help="Top-p (nucleus) sampling parameter")
435+
rft_parser.add_argument("--top-k", type=int, help="Top-k sampling parameter")
436+
rft_parser.add_argument("--max-output-tokens", type=int, default=32768, help="Maximum output tokens per rollout")
437+
rft_parser.add_argument(
438+
"--response-candidates-count", type=int, default=8, help="Number of response candidates per prompt"
439+
)
420440
rft_parser.add_argument("--extra-body", help="JSON string for extra inference params")
421441
# MCP server (optional)
422442
rft_parser.add_argument(
423443
"--mcp-server",
424-
help="The MCP server resource name to use for the reinforcement fine-tuning job.",
444+
help="MCP server resource name for agentic rollouts",
425445
)
426446
# Wandb
427-
rft_parser.add_argument("--wandb-enabled", action="store_true")
428-
rft_parser.add_argument("--wandb-project")
429-
rft_parser.add_argument("--wandb-entity")
430-
rft_parser.add_argument("--wandb-run-id")
431-
rft_parser.add_argument("--wandb-api-key")
447+
rft_parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights & Biases logging")
448+
rft_parser.add_argument("--wandb-project", help="Weights & Biases project name")
449+
rft_parser.add_argument("--wandb-entity", help="Weights & Biases entity (username or team)")
450+
rft_parser.add_argument("--wandb-run-id", help="Weights & Biases run id for resuming")
451+
rft_parser.add_argument("--wandb-api-key", help="Weights & Biases API key")
432452
# Misc
433453
rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id")
434454
rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
@@ -494,6 +514,38 @@ def parse_args(args=None):
494514
# help="Run an evaluation using a Hydra configuration. All arguments after 'run' are passed to Hydra.",
495515
# )
496516

517+
# Hidden command: export-docs (for generating CLI reference documentation)
518+
export_docs_parser = subparsers.add_parser("export-docs", help=argparse.SUPPRESS)
519+
export_docs_parser.add_argument(
520+
"--output",
521+
"-o",
522+
default="./docs/cli-reference.md",
523+
help="Output markdown file path (default: ./docs/cli-reference.md)",
524+
)
525+
526+
# Update metavar to only show visible commands (exclude those with SUPPRESS)
527+
_hide_suppressed_subparsers(parser)
528+
529+
return parser
530+
531+
532+
def _hide_suppressed_subparsers(parser: argparse.ArgumentParser) -> None:
533+
"""Update subparsers to exclude commands with help=SUPPRESS from help output."""
534+
for action in parser._actions:
535+
if isinstance(action, argparse._SubParsersAction):
536+
# Filter _choices_actions to only visible commands
537+
choices_actions = getattr(action, "_choices_actions", [])
538+
visible_actions = [a for a in choices_actions if a.help != argparse.SUPPRESS]
539+
action._choices_actions = visible_actions
540+
# Update metavar to match
541+
visible_names = [a.dest for a in visible_actions]
542+
if visible_names:
543+
action.metavar = "{" + ",".join(visible_names) + "}"
544+
545+
546+
def parse_args(args=None):
547+
"""Parse command line arguments."""
548+
parser = build_parser()
497549
# Use parse_known_args to allow Hydra to handle its own arguments
498550
return parser.parse_known_args(args)
499551

@@ -623,6 +675,10 @@ def _extract_flag_value(argv_list, flag_name):
623675
from .cli_commands.local_test import local_test_command
624676

625677
return local_test_command(args)
678+
elif args.command == "export-docs":
679+
from .cli_commands.export_docs import export_docs_command
680+
681+
return export_docs_command(args)
626682
# elif args.command == "run":
627683
# # For the 'run' command, Hydra takes over argument parsing.
628684
#

0 commit comments

Comments
 (0)