Skip to content

Commit cefc461

Browse files
author
Dylan Huang
committed
update
1 parent a1d4cd5 commit cefc461

File tree

2 files changed

+139
-60
lines changed

2 files changed

+139
-60
lines changed

eval_protocol/cli.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@
3434

3535
def build_parser() -> argparse.ArgumentParser:
3636
"""Build and return the argument parser for the CLI."""
37-
parser = argparse.ArgumentParser(description="eval-protocol: Tools for evaluation and reward modeling")
37+
parser = argparse.ArgumentParser(
38+
description="Inspect evaluation runs locally, upload evaluators, and create reinforcement fine-tuning jobs on Fireworks"
39+
)
3840
return _configure_parser(parser)
3941

4042

@@ -401,39 +403,52 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
401403
rft_parser.add_argument("--base-model", help="Base model resource id")
402404
rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from")
403405
rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)")
404-
rft_parser.add_argument("--epochs", type=int, default=1)
405-
rft_parser.add_argument("--batch-size", type=int, default=128000)
406-
rft_parser.add_argument("--learning-rate", type=float, default=3e-5)
407-
rft_parser.add_argument("--max-context-length", type=int, default=65536)
408-
rft_parser.add_argument("--lora-rank", type=int, default=16)
406+
rft_parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
407+
rft_parser.add_argument("--batch-size", type=int, default=128000, help="Training batch size in tokens")
408+
rft_parser.add_argument("--learning-rate", type=float, default=3e-5, help="Learning rate for training")
409+
rft_parser.add_argument("--max-context-length", type=int, default=65536, help="Maximum context length in tokens")
410+
rft_parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank for fine-tuning")
409411
rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps")
410-
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of LR warmup steps")
411-
rft_parser.add_argument("--accelerator-count", type=int)
412-
rft_parser.add_argument("--region", help="Fireworks region enum value")
413-
rft_parser.add_argument("--display-name", help="RFT job display name")
414-
rft_parser.add_argument("--evaluation-dataset", help="Optional separate eval dataset id")
415-
rft_parser.add_argument("--eval-auto-carveout", dest="eval_auto_carveout", action="store_true", default=True)
416-
rft_parser.add_argument("--no-eval-auto-carveout", dest="eval_auto_carveout", action="store_false")
412+
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of learning rate warmup steps")
413+
rft_parser.add_argument("--accelerator-count", type=int, help="Number of accelerators (GPUs) to use")
414+
rft_parser.add_argument("--region", help="Fireworks region for training")
415+
rft_parser.add_argument("--display-name", help="Display name for the RFT job")
416+
rft_parser.add_argument("--evaluation-dataset", help="Separate dataset id for evaluation")
417+
rft_parser.add_argument(
418+
"--eval-auto-carveout",
419+
dest="eval_auto_carveout",
420+
action="store_true",
421+
default=True,
422+
help="Automatically carve out evaluation data from training set",
423+
)
424+
rft_parser.add_argument(
425+
"--no-eval-auto-carveout",
426+
dest="eval_auto_carveout",
427+
action="store_false",
428+
help="Disable automatic evaluation data carveout",
429+
)
417430
# Rollout chunking
418431
rft_parser.add_argument("--chunk-size", type=int, default=100, help="Data chunk size for rollout batching")
419432
# Inference params
420-
rft_parser.add_argument("--temperature", type=float)
421-
rft_parser.add_argument("--top-p", type=float)
422-
rft_parser.add_argument("--top-k", type=int)
423-
rft_parser.add_argument("--max-output-tokens", type=int, default=32768)
424-
rft_parser.add_argument("--response-candidates-count", type=int, default=8)
433+
rft_parser.add_argument("--temperature", type=float, help="Sampling temperature for rollouts")
434+
rft_parser.add_argument("--top-p", type=float, help="Top-p (nucleus) sampling parameter")
435+
rft_parser.add_argument("--top-k", type=int, help="Top-k sampling parameter")
436+
rft_parser.add_argument("--max-output-tokens", type=int, default=32768, help="Maximum output tokens per rollout")
437+
rft_parser.add_argument(
438+
"--response-candidates-count", type=int, default=8, help="Number of response candidates per prompt"
439+
)
425440
rft_parser.add_argument("--extra-body", help="JSON string for extra inference params")
426441
# MCP server (optional)
427442
rft_parser.add_argument(
428443
"--mcp-server",
429-
help="The MCP server resource name to use for the reinforcement fine-tuning job.",
444+
help="MCP server resource name for agentic rollouts",
430445
)
431446
# Wandb
432-
rft_parser.add_argument("--wandb-enabled", action="store_true")
433-
rft_parser.add_argument("--wandb-project")
434-
rft_parser.add_argument("--wandb-entity")
435-
rft_parser.add_argument("--wandb-run-id")
436-
rft_parser.add_argument("--wandb-api-key")
447+
rft_parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights & Biases logging")
448+
rft_parser.add_argument("--wandb-project", help="Weights & Biases project name")
449+
rft_parser.add_argument("--wandb-entity", help="Weights & Biases entity (username or team)")
450+
rft_parser.add_argument("--wandb-run-id", help="Weights & Biases run id for resuming")
451+
rft_parser.add_argument("--wandb-api-key", help="Weights & Biases API key")
437452
# Misc
438453
rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id")
439454
rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")

eval_protocol/cli_commands/export_docs.py

Lines changed: 100 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -59,39 +59,83 @@ def _get_parser_info(parser: argparse.ArgumentParser, subparser_help: str = "")
5959
return info
6060

6161

62-
def _format_argument_row(arg: Dict) -> str:
63-
"""Format a single argument as a markdown table row."""
64-
# Build the flag/argument name
62+
def _format_argument_item(arg: Dict) -> List[str]:
63+
"""Format a single argument as a Mintlify ParamField component."""
64+
lines = []
65+
66+
# Build the flag name
6567
if arg["option_strings"]:
66-
name = ", ".join(f"`{opt}`" for opt in arg["option_strings"])
68+
long_opts = [o for o in arg["option_strings"] if o.startswith("--")]
69+
short_opts = [o for o in arg["option_strings"] if not o.startswith("--")]
70+
primary = long_opts[0] if long_opts else arg["option_strings"][0]
6771
else:
68-
name = f"`{arg['dest']}`"
72+
primary = arg["dest"]
73+
short_opts = []
6974

70-
# Build type info
75+
# Map Python types to ParamField types
7176
type_str = ""
7277
if arg["type"]:
73-
type_str = getattr(arg["type"], "__name__", str(arg["type"]))
74-
if arg["choices"]:
75-
type_str = f"choices: {arg['choices']}"
76-
77-
# Format default value
78+
python_type = getattr(arg["type"], "__name__", str(arg["type"]))
79+
type_map = {"int": "number", "float": "number", "str": "string", "bool": "boolean"}
80+
type_str = type_map.get(python_type, python_type)
81+
elif arg["default"] is not None:
82+
# Infer type from default
83+
if isinstance(arg["default"], bool):
84+
type_str = "boolean"
85+
elif isinstance(arg["default"], int):
86+
type_str = "number"
87+
elif isinstance(arg["default"], float):
88+
type_str = "number"
89+
elif isinstance(arg["default"], str):
90+
type_str = "string"
91+
92+
# Build ParamField attributes
93+
attrs = [f'path="{primary}"']
94+
95+
if type_str:
96+
attrs.append(f'type="{type_str}"')
97+
98+
# Default value
7899
default = arg["default"]
79-
if default is None:
80-
default_str = "-"
81-
elif default == argparse.SUPPRESS:
82-
default_str = "-"
83-
elif isinstance(default, bool):
84-
default_str = str(default).lower()
85-
else:
86-
default_str = f"`{default}`"
100+
if default is not None and default != argparse.SUPPRESS:
101+
if isinstance(default, bool):
102+
default_str = str(default).lower()
103+
elif isinstance(default, str):
104+
# Escape quotes in string defaults
105+
default_str = default.replace('"', '\\"')
106+
else:
107+
default_str = str(default)
108+
attrs.append(f'default="{default_str}"')
109+
110+
if arg["required"]:
111+
attrs.append("required")
112+
113+
# Build description with short alias mention
114+
help_text = (arg["help"] or "").replace("<", "&lt;").replace(">", "&gt;")
115+
if short_opts:
116+
alias_note = f"Short: `{short_opts[0]}`"
117+
if help_text:
118+
help_text = f"{help_text} ({alias_note})"
119+
else:
120+
help_text = alias_note
87121

88-
# Help text (escape pipe characters for markdown tables)
89-
help_text = (arg["help"] or "-").replace("|", "\\|")
122+
# Add choices info to description
123+
if arg["choices"]:
124+
choices_str = ", ".join(f"`{c}`" for c in arg["choices"])
125+
choices_note = f"Choices: {choices_str}"
126+
if help_text:
127+
help_text = f"{help_text}. {choices_note}"
128+
else:
129+
help_text = choices_note
90130

91-
# Required indicator
92-
required = "Yes" if arg["required"] else "No"
131+
# Generate ParamField
132+
lines.append(f"<ParamField {' '.join(attrs)}>")
133+
if help_text:
134+
lines.append(f" {help_text}")
135+
lines.append("</ParamField>")
136+
lines.append("")
93137

94-
return f"| {name} | {type_str} | {default_str} | {required} | {help_text} |"
138+
return lines
95139

96140

97141
def _generate_command_section(
@@ -105,6 +149,21 @@ def _generate_command_section(
105149
full_command = f"{parent_command} {name}".strip()
106150
heading = "#" * heading_level
107151

152+
# Skip commands that have no arguments and only subparsers (like "ep create")
153+
# Instead, just render the subcommands directly at the same level
154+
if not info["arguments"] and info["subparsers"]:
155+
# Skip this level, render subcommands directly
156+
for subname, subinfo in info["subparsers"].items():
157+
lines.extend(
158+
_generate_command_section(
159+
subname,
160+
subinfo,
161+
full_command,
162+
heading_level, # Keep same heading level
163+
)
164+
)
165+
return lines
166+
108167
lines.append(f"{heading} `{full_command}`")
109168
lines.append("")
110169

@@ -114,13 +173,10 @@ def _generate_command_section(
114173
lines.append(description)
115174
lines.append("")
116175

117-
# Arguments table
176+
# Arguments (no extra heading to keep TOC clean)
118177
if info["arguments"]:
119-
lines.append("| Option | Type | Default | Required | Description |")
120-
lines.append("|--------|------|---------|----------|-------------|")
121178
for arg in info["arguments"]:
122-
lines.append(_format_argument_row(arg))
123-
lines.append("")
179+
lines.extend(_format_argument_item(arg))
124180

125181
# Handle nested subparsers recursively
126182
if info["subparsers"]:
@@ -162,22 +218,30 @@ def generate_cli_docs(parser: argparse.ArgumentParser, output_path: str) -> int:
162218
if name != "export-docs" # Don't document the hidden command
163219
}
164220

165-
# Generate single page
221+
# Generate single page with Mintlify frontmatter
166222
lines = []
167-
lines.append("# CLI Reference")
223+
lines.append("---")
224+
lines.append("title: CLI")
225+
lines.append("icon: terminal")
226+
lines.append("---")
168227
lines.append("")
169-
lines.append(f"**{info['prog']}** - {info['description']}")
228+
lines.append(
229+
f"The `{info['prog']}` command-line interface can {info['description'][0].lower()}{info['description'][1:]}."
230+
)
231+
lines.append("")
232+
lines.append("```bash")
233+
lines.append(f"{info['prog']} [global options] <command> [command options]")
234+
lines.append("```")
170235
lines.append("")
171236

172237
# Global options
173238
if info["arguments"]:
174239
lines.append("## Global Options")
175240
lines.append("")
176-
lines.append("| Option | Type | Default | Required | Description |")
177-
lines.append("|--------|------|---------|----------|-------------|")
178-
for arg in info["arguments"]:
179-
lines.append(_format_argument_row(arg))
241+
lines.append("These options can be used with any command:")
180242
lines.append("")
243+
for arg in info["arguments"]:
244+
lines.extend(_format_argument_item(arg))
181245

182246
# Commands section
183247
if visible_subparsers:

0 commit comments

Comments
 (0)