-
Notifications
You must be signed in to change notification settings - Fork 16
auto generated cli #384
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
auto generated cli #384
Changes from 2 commits
1df9e72
76e9cec
5885822
ac5be36
35c66c9
2bb176d
c209678
0d0a50c
665ea5e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,11 +3,16 @@ | |
| """ | ||
|
|
||
| import argparse | ||
| import inspect | ||
| import json | ||
| import logging | ||
| import os | ||
| import sys | ||
| from pathlib import Path | ||
| from typing import Any, cast | ||
| from .cli_commands.utils import add_args_from_callable_signature | ||
|
|
||
| from fireworks import Fireworks | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
@@ -374,87 +379,11 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse | |
| "rft", | ||
| help="Create a Reinforcement Fine-tuning Job on Fireworks", | ||
| ) | ||
| rft_parser.add_argument( | ||
| "--evaluator", | ||
| help="Evaluator ID or fully-qualified resource (accounts/{acct}/evaluators/{id}); if omitted, derive from local tests", | ||
| ) | ||
| # Dataset options | ||
| rft_parser.add_argument( | ||
| "--dataset", | ||
| help="Use existing dataset (ID or resource 'accounts/{acct}/datasets/{id}') to skip local materialization", | ||
| ) | ||
| rft_parser.add_argument( | ||
| "--dataset-jsonl", | ||
| help="Path to JSONL to upload as a new Fireworks dataset", | ||
| ) | ||
| rft_parser.add_argument( | ||
| "--dataset-builder", | ||
| help="Explicit dataset builder spec (module::function or path::function)", | ||
| ) | ||
| rft_parser.add_argument( | ||
| "--dataset-display-name", | ||
| help="Display name for dataset on Fireworks (defaults to dataset id)", | ||
| ) | ||
| # Training config and evaluator/job settings | ||
| rft_parser.add_argument("--base-model", help="Base model resource id") | ||
| rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from") | ||
| rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)") | ||
| rft_parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs") | ||
| rft_parser.add_argument("--batch-size", type=int, default=128000, help="Training batch size in tokens") | ||
| rft_parser.add_argument("--learning-rate", type=float, default=3e-5, help="Learning rate for training") | ||
| rft_parser.add_argument("--max-context-length", type=int, default=65536, help="Maximum context length in tokens") | ||
| rft_parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank for fine-tuning") | ||
| rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps") | ||
| rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of learning rate warmup steps") | ||
| rft_parser.add_argument("--accelerator-count", type=int, help="Number of accelerators (GPUs) to use") | ||
| rft_parser.add_argument("--region", help="Fireworks region for training") | ||
| rft_parser.add_argument("--display-name", help="Display name for the RFT job") | ||
| rft_parser.add_argument("--evaluation-dataset", help="Separate dataset id for evaluation") | ||
| rft_parser.add_argument( | ||
| "--eval-auto-carveout", | ||
| dest="eval_auto_carveout", | ||
| action="store_true", | ||
| default=True, | ||
| help="Automatically carve out evaluation data from training set", | ||
| ) | ||
| rft_parser.add_argument( | ||
| "--no-eval-auto-carveout", | ||
| dest="eval_auto_carveout", | ||
| action="store_false", | ||
| help="Disable automatic evaluation data carveout", | ||
| ) | ||
| # Rollout chunking | ||
| rft_parser.add_argument("--chunk-size", type=int, default=100, help="Data chunk size for rollout batching") | ||
| # Inference params | ||
| rft_parser.add_argument("--temperature", type=float, help="Sampling temperature for rollouts") | ||
| rft_parser.add_argument("--top-p", type=float, help="Top-p (nucleus) sampling parameter") | ||
| rft_parser.add_argument("--top-k", type=int, help="Top-k sampling parameter") | ||
| rft_parser.add_argument("--max-output-tokens", type=int, default=32768, help="Maximum output tokens per rollout") | ||
| rft_parser.add_argument( | ||
| "--response-candidates-count", type=int, default=8, help="Number of response candidates per prompt" | ||
| ) | ||
| rft_parser.add_argument("--extra-body", help="JSON string for extra inference params") | ||
| # MCP server (optional) | ||
| rft_parser.add_argument( | ||
| "--mcp-server", | ||
| help="MCP server resource name for agentic rollouts", | ||
| ) | ||
| # Wandb | ||
| rft_parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights & Biases logging") | ||
| rft_parser.add_argument("--wandb-project", help="Weights & Biases project name") | ||
| rft_parser.add_argument("--wandb-entity", help="Weights & Biases entity (username or team)") | ||
| rft_parser.add_argument("--wandb-run-id", help="Weights & Biases run id for resuming") | ||
| rft_parser.add_argument("--wandb-api-key", help="Weights & Biases API key") | ||
| # Misc | ||
| rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id") | ||
|
|
||
| rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode") | ||
| rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending") | ||
| rft_parser.add_argument("--dry-run", action="store_true", help="Print planned SDK call without sending") | ||
| rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID") | ||
| rft_parser.add_argument( | ||
| "--skip-validation", | ||
| action="store_true", | ||
| help="Skip local dataset and evaluator validation before creating the RFT job", | ||
| ) | ||
| rft_parser.add_argument("--skip-validation", action="store_true", help="Skip local dataset/evaluator validation") | ||
| rft_parser.add_argument( | ||
| "--ignore-docker", | ||
| action="store_true", | ||
|
|
@@ -463,14 +392,64 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse | |
| rft_parser.add_argument( | ||
| "--docker-build-extra", | ||
| default="", | ||
| metavar="", | ||
| help="Extra flags to pass to 'docker build' when validating evaluator (quoted string, e.g. \"--no-cache --pull --progress=plain\")", | ||
| ) | ||
| rft_parser.add_argument( | ||
| "--docker-run-extra", | ||
| default="", | ||
| metavar="", | ||
| help="Extra flags to pass to 'docker run' when validating evaluator (quoted string, e.g. \"--env-file .env --memory=8g\")", | ||
| ) | ||
|
|
||
| # Everything below has to manually be maintained, can't be auto-generated | ||
|
xzrderek marked this conversation as resolved.
Outdated
|
||
| rft_parser.add_argument( | ||
| "--source-job", | ||
| metavar="", | ||
| help="The source reinforcement fine-tuning job to copy configuration from. If other flags are set, they will override the source job's configuration.", | ||
| ) | ||
| rft_parser.add_argument( | ||
| "--quiet", | ||
| action="store_true", | ||
| help="If set, only errors will be printed.", | ||
| ) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing CLI arguments for dataset creation workflow. The Additional Locations (1) |
||
| skip_fields = { | ||
| "__top_level__": { | ||
| "extra_headers", | ||
| "extra_query", | ||
| "extra_body", | ||
| "timeout", | ||
| "node_count", | ||
| "display_name", | ||
| "account_id", | ||
| }, | ||
| "loss_config": {"kl_beta", "method"}, | ||
|
xzrderek marked this conversation as resolved.
Outdated
|
||
| "training_config": {"region", "jinja_template"}, | ||
|
xzrderek marked this conversation as resolved.
|
||
| "wandb_config": {"run_id"}, | ||
|
xzrderek marked this conversation as resolved.
|
||
| } | ||
| aliases = { | ||
| "wandb_config.api_key": ["--wandb-api-key"], | ||
| "wandb_config.project": ["--wandb-project"], | ||
| "wandb_config.entity": ["--wandb-entity"], | ||
| "wandb_config.enabled": ["--wandb"], | ||
| "reinforcement_fine_tuning_job_id": ["--job-id"], | ||
| } | ||
|
xzrderek marked this conversation as resolved.
|
||
| help_overrides = { | ||
| "training_config.gradient_accumulation_steps": "The number of batches to accumulate gradients before updating the model parameters. The effective batch size will be batch-size multiplied by this value.", | ||
| "training_config.learning_rate_warmup_steps": "The number of learning rate warmup steps for the reinforcement fine-tuning job.", | ||
| "mcp_server": "The MCP server resource name to use for the reinforcement fine-tuning job. (Optional)", | ||
| } | ||
|
|
||
| create_rft_job_fn = Fireworks().reinforcement_fine_tuning_jobs.create | ||
|
xzrderek marked this conversation as resolved.
|
||
|
|
||
| add_args_from_callable_signature( | ||
| rft_parser, | ||
| create_rft_job_fn, | ||
| skip_fields=skip_fields, | ||
| aliases=aliases, | ||
| help_overrides=help_overrides, | ||
| ) | ||
|
|
||
| # Local test command | ||
| local_test_parser = subparsers.add_parser( | ||
| "local-test", | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.