Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
137 changes: 58 additions & 79 deletions eval_protocol/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
"""

import argparse
import inspect
import json
import logging
import os
import sys
from pathlib import Path
from typing import Any, cast
from .cli_commands.utils import add_args_from_callable_signature

from fireworks import Fireworks

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -374,87 +379,11 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
"rft",
help="Create a Reinforcement Fine-tuning Job on Fireworks",
)
rft_parser.add_argument(
"--evaluator",
help="Evaluator ID or fully-qualified resource (accounts/{acct}/evaluators/{id}); if omitted, derive from local tests",
)
# Dataset options
rft_parser.add_argument(
"--dataset",
help="Use existing dataset (ID or resource 'accounts/{acct}/datasets/{id}') to skip local materialization",
)
rft_parser.add_argument(
"--dataset-jsonl",
help="Path to JSONL to upload as a new Fireworks dataset",
)
rft_parser.add_argument(
"--dataset-builder",
help="Explicit dataset builder spec (module::function or path::function)",
)
rft_parser.add_argument(
"--dataset-display-name",
help="Display name for dataset on Fireworks (defaults to dataset id)",
)
# Training config and evaluator/job settings
rft_parser.add_argument("--base-model", help="Base model resource id")
rft_parser.add_argument("--warm-start-from", help="Addon model to warm start from")
rft_parser.add_argument("--output-model", help="Output model id (defaults from evaluator)")
rft_parser.add_argument("--epochs", type=int, default=1, help="Number of training epochs")
rft_parser.add_argument("--batch-size", type=int, default=128000, help="Training batch size in tokens")
rft_parser.add_argument("--learning-rate", type=float, default=3e-5, help="Learning rate for training")
rft_parser.add_argument("--max-context-length", type=int, default=65536, help="Maximum context length in tokens")
rft_parser.add_argument("--lora-rank", type=int, default=16, help="LoRA rank for fine-tuning")
rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps")
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of learning rate warmup steps")
rft_parser.add_argument("--accelerator-count", type=int, help="Number of accelerators (GPUs) to use")
rft_parser.add_argument("--region", help="Fireworks region for training")
rft_parser.add_argument("--display-name", help="Display name for the RFT job")
rft_parser.add_argument("--evaluation-dataset", help="Separate dataset id for evaluation")
rft_parser.add_argument(
"--eval-auto-carveout",
dest="eval_auto_carveout",
action="store_true",
default=True,
help="Automatically carve out evaluation data from training set",
)
rft_parser.add_argument(
"--no-eval-auto-carveout",
dest="eval_auto_carveout",
action="store_false",
help="Disable automatic evaluation data carveout",
)
# Rollout chunking
rft_parser.add_argument("--chunk-size", type=int, default=100, help="Data chunk size for rollout batching")
# Inference params
rft_parser.add_argument("--temperature", type=float, help="Sampling temperature for rollouts")
rft_parser.add_argument("--top-p", type=float, help="Top-p (nucleus) sampling parameter")
rft_parser.add_argument("--top-k", type=int, help="Top-k sampling parameter")
rft_parser.add_argument("--max-output-tokens", type=int, default=32768, help="Maximum output tokens per rollout")
rft_parser.add_argument(
"--response-candidates-count", type=int, default=8, help="Number of response candidates per prompt"
)
rft_parser.add_argument("--extra-body", help="JSON string for extra inference params")
# MCP server (optional)
rft_parser.add_argument(
"--mcp-server",
help="MCP server resource name for agentic rollouts",
)
# Wandb
rft_parser.add_argument("--wandb-enabled", action="store_true", help="Enable Weights & Biases logging")
rft_parser.add_argument("--wandb-project", help="Weights & Biases project name")
rft_parser.add_argument("--wandb-entity", help="Weights & Biases entity (username or team)")
rft_parser.add_argument("--wandb-run-id", help="Weights & Biases run id for resuming")
rft_parser.add_argument("--wandb-api-key", help="Weights & Biases API key")
# Misc
rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id")

rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending")
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned SDK call without sending")
rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID")
rft_parser.add_argument(
"--skip-validation",
action="store_true",
help="Skip local dataset and evaluator validation before creating the RFT job",
)
rft_parser.add_argument("--skip-validation", action="store_true", help="Skip local dataset/evaluator validation")
rft_parser.add_argument(
"--ignore-docker",
action="store_true",
Expand All @@ -463,14 +392,64 @@ def _configure_parser(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
rft_parser.add_argument(
"--docker-build-extra",
default="",
metavar="",
Comment thread
xzrderek marked this conversation as resolved.
help="Extra flags to pass to 'docker build' when validating evaluator (quoted string, e.g. \"--no-cache --pull --progress=plain\")",
)
rft_parser.add_argument(
"--docker-run-extra",
default="",
metavar="",
help="Extra flags to pass to 'docker run' when validating evaluator (quoted string, e.g. \"--env-file .env --memory=8g\")",
)

# Everything below has to manually be maintained, can't be auto-generated
Comment thread
xzrderek marked this conversation as resolved.
Outdated
rft_parser.add_argument(
"--source-job",
metavar="",
help="The source reinforcement fine-tuning job to copy configuration from. If other flags are set, they will override the source job's configuration.",
)
rft_parser.add_argument(
"--quiet",
action="store_true",
help="If set, only errors will be printed.",
)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Missing CLI arguments for dataset creation workflow

The --dataset-jsonl, --dataset-builder, and --dataset-display-name arguments were removed from the CLI but the code in create_rft.py still expects them via getattr(args, "dataset_jsonl", None) and getattr(args, "dataset_display_name", None). These are workflow-specific arguments (not SDK parameters) for creating datasets from local JSONL files. Users attempting to use --dataset-jsonl will get "unrecognized arguments" errors, and the dataset creation workflow from local files is broken. The comment at line 405-406 acknowledges that workflow controls must be maintained manually, but these arguments were not added.

Additional Locations (1)

Fix in Cursor Fix in Web

skip_fields = {
"__top_level__": {
"extra_headers",
"extra_query",
"extra_body",
"timeout",
"node_count",
"display_name",
"account_id",
},
"loss_config": {"kl_beta", "method"},
Comment thread
xzrderek marked this conversation as resolved.
Outdated
"training_config": {"region", "jinja_template"},
Comment thread
xzrderek marked this conversation as resolved.
"wandb_config": {"run_id"},
Comment thread
xzrderek marked this conversation as resolved.
}
aliases = {
"wandb_config.api_key": ["--wandb-api-key"],
"wandb_config.project": ["--wandb-project"],
"wandb_config.entity": ["--wandb-entity"],
"wandb_config.enabled": ["--wandb"],
"reinforcement_fine_tuning_job_id": ["--job-id"],
}
Comment thread
xzrderek marked this conversation as resolved.
help_overrides = {
"training_config.gradient_accumulation_steps": "The number of batches to accumulate gradients before updating the model parameters. The effective batch size will be batch-size multiplied by this value.",
"training_config.learning_rate_warmup_steps": "The number of learning rate warmup steps for the reinforcement fine-tuning job.",
"mcp_server": "The MCP server resource name to use for the reinforcement fine-tuning job. (Optional)",
}

create_rft_job_fn = Fireworks().reinforcement_fine_tuning_jobs.create
Comment thread
xzrderek marked this conversation as resolved.

add_args_from_callable_signature(
rft_parser,
create_rft_job_fn,
skip_fields=skip_fields,
aliases=aliases,
help_overrides=help_overrides,
)

# Local test command
local_test_parser = subparsers.add_parser(
"local-test",
Expand Down
204 changes: 201 additions & 3 deletions eval_protocol/cli_commands/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import os
import ast
import sys
import time
import inspect
import subprocess
import argparse
import typing
import types
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

from typing import Any, List, Optional, is_typeddict
import typing_extensions
import inspect
from collections.abc import Callable
import pytest

from ..auth import (
Expand Down Expand Up @@ -505,3 +510,196 @@ def _build_entry_point(project_root: str, source_file_path: Optional[str], func_
return f"{rel}::{func_name}"
# Fallback: use filename only
return f"{func_name}.py::{func_name}"


def unwrap_union(tp):
origin = typing.get_origin(tp)

# Handles both typing.Union[...] and PEP604 unions (A | B)
if origin is typing.Union or origin is types.UnionType:
args = [a for a in typing.get_args(tp) if getattr(a, "__name__", "") != "Omit" and a is not type(None)]
return args[0] if args else None

return tp


def argparse_type_from_hint(t: Any) -> Any:
    """Return a callable argparse ``type=`` converter for a type hint.

    - Drops Omit/None members from unions (via :func:`unwrap_union`)
    - Unwraps ``Annotated[T, ...]`` => ``T``
    - Falls back to ``str`` when the final result isn't callable

    Unwrapping is repeated until the hint stops changing so that nested
    wrappers resolve in either order, e.g. both ``Optional[Annotated[int, ...]]``
    and ``Annotated[Optional[int], ...]`` reduce to ``int``.
    """
    while True:
        unwrapped = unwrap_union(t)
        if typing.get_origin(unwrapped) is typing.Annotated:
            args = typing.get_args(unwrapped)
            unwrapped = args[0] if args else str
        if unwrapped is t:
            break  # fixed point reached — nothing left to strip
        t = unwrapped
    return t if callable(t) else str


def typed_dict_field_docs(typed_dict_cls: type) -> dict[str, str]:
    """
    Extract per-field docstrings from a TypedDict class that uses the pattern:

        field: Type
        'doc...'

    Returns { "field": "doc..." }. Best-effort: returns {} whenever the
    class source cannot be retrieved or parsed.
    """
    import textwrap  # local import: only needed for this best-effort parse

    try:
        src = inspect.getsource(typed_dict_cls)
    except Exception:
        # Source unavailable (built-in, REPL-defined, frozen module, ...).
        return {}

    try:
        # Dedent so classes defined at non-zero indentation (inside a
        # function or another class) still parse instead of raising
        # SyntaxError and silently losing all field docs.
        mod = ast.parse(textwrap.dedent(src))
    except SyntaxError:
        return {}

    # Find the class definition node for this specific class.
    cls_node = None
    for node in mod.body:
        if isinstance(node, ast.ClassDef) and node.name == typed_dict_cls.__name__:
            cls_node = node
            break
    if cls_node is None:
        return {}

    docs: dict[str, str] = {}
    body = cls_node.body

    i = 0
    while i < len(body):
        node = body[i]

        # field: Annotated[...] or field: T
        if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
            field_name = node.target.id

            # A string literal expression immediately after the annotation
            # is treated as that field's doc.
            if i + 1 < len(body):
                nxt = body[i + 1]
                if (
                    isinstance(nxt, ast.Expr)
                    and isinstance(nxt.value, ast.Constant)
                    and isinstance(nxt.value.value, str)
                ):
                    docs[field_name] = nxt.value.value.strip()
                    i += 2
                    continue

        i += 1

    return docs


def _parse_args_section_from_doc(doc: str) -> dict[str, str]:
if not doc:
return {}

lines = doc.splitlines()

# find "Args:"
try:
start = next(i for i, line in enumerate(lines) if line.strip() == "Args:")
except StopIteration:
return {}

out: dict[str, str] = {}
cur_name: str | None = None
cur_parts: list[str] = []

for line in lines[start + 1 :]:
# stop if we hit another top-level section header like "Returns:"
if line and not line.startswith(" ") and line.endswith(":"):
break

if not line.strip():
continue

stripped = line.strip()

# New arg header like "dataset: blah"
if ":" in stripped:
name, rest = stripped.split(":", 1)
name = name.strip()
if name.replace("_", "").isalnum():
if cur_name:
out[cur_name] = " ".join(cur_parts).strip()
cur_name = name
cur_parts = [rest.strip()]
continue

# Continuation
if cur_name:
cur_parts.append(stripped)

if cur_name:
out[cur_name] = " ".join(cur_parts).strip()

return out


def _add_flag(
    parser: argparse.ArgumentParser,
    flags: list[str],
    hint: Any,
    help_text: str | None,
) -> None:
    """Register a single CLI option on *parser* from a type hint.

    Boolean hints become ``store_true`` switches; all other hints become
    value-taking options whose converter is derived from the hint
    (falling back to ``str``). ``metavar=""`` keeps --help output compact.
    """
    if unwrap_union(hint) is bool:
        parser.add_argument(*flags, action="store_true", help=help_text)
    else:
        parser.add_argument(
            *flags,
            type=argparse_type_from_hint(hint),
            help=help_text,
            metavar="",
        )


def add_args_from_callable_signature(
    parser: argparse.ArgumentParser,
    fn: Callable[..., Any],
    *,
    overrides: dict[str, str] | None = None,
    skip_fields: dict[str, set[str]] | None = None,
    aliases: dict[str, list[str]] | None = None,
    help_overrides: dict[str, str] | None = None,
) -> None:
    """Auto-generate argparse options from *fn*'s signature and docstring.

    Each parameter of *fn* becomes a ``--kebab-case`` flag. Help text comes
    from the ``Args:`` section of *fn*'s docstring unless overridden via
    *help_overrides*. Parameters whose (union-unwrapped) type is a TypedDict
    are expanded one level deep into flags for each field, using the field's
    attribute docstrings for help.

    Args:
        parser: The parser (or subparser) to add arguments to.
        fn: The callable whose signature drives flag generation.
        overrides: Reserved for callers; currently unused here.
        skip_fields: Fields to omit, keyed by parameter name; the special
            key ``"__top_level__"`` skips top-level parameters.
        aliases: Extra flag spellings, keyed by ``name`` or ``name.field``.
        help_overrides: Help-text replacements, keyed like *aliases*.
    """
    overrides = overrides or {}
    aliases = aliases or {}
    help_overrides = help_overrides or {}
    skip_fields = skip_fields or {}
    top_level_skip = skip_fields.get("__top_level__", set())

    sig = inspect.signature(fn)
    # Renamed from `help` to avoid shadowing the builtin.
    arg_docs = _parse_args_section_from_doc(inspect.getdoc(fn) or "")
    hints = typing.get_type_hints(fn, include_extras=True)

    # Only parameter names are needed; the Parameter objects are unused.
    for name in sig.parameters:
        resolved_type = unwrap_union(hints.get(name))

        # Expand one nested layer of TypedDict config objects into flat flags.
        if resolved_type is not None and typing_extensions.is_typeddict(resolved_type):
            field_help = typed_dict_field_docs(resolved_type)
            field_hints = typing.get_type_hints(resolved_type, include_extras=True)
            field_skip = skip_fields.get(name, set())

            for field_name, field_type in resolved_type.__annotations__.items():
                if field_name in field_skip:
                    continue
                flag_name = "--" + field_name.replace("_", "-")
                flags = [flag_name] + aliases.get(f"{name}.{field_name}", [])
                help_text = help_overrides.get(f"{name}.{field_name}", field_help.get(field_name))

                _add_flag(parser, flags, field_hints.get(field_name, field_type), help_text)
            continue

        if name in top_level_skip:
            continue

        flag_name = "--" + name.replace("_", "-")
        flags = [flag_name] + aliases.get(name, [])
        help_text = help_overrides.get(name, arg_docs.get(name))

        _add_flag(parser, flags, hints.get(name), help_text)
Loading
Loading