Commit ae6cbc4

unify flags

1 parent 16648fe

File tree

3 files changed: +158 -41 lines changed


eval_protocol/cli.py

Lines changed: 14 additions & 7 deletions
@@ -371,13 +371,13 @@ def parse_args(args=None):
         help="Create a Reinforcement Fine-tuning Job on Fireworks",
     )
     rft_parser.add_argument(
-        "--evaluator-id",
-        help="Evaluator ID used during upload; if omitted, derive from local traces or a single discovered test",
+        "--evaluator",
+        help="Evaluator ID or fully-qualified resource (accounts/{acct}/evaluators/{id}); if omitted, derive from local tests",
     )
     # Dataset options
     rft_parser.add_argument(
-        "--dataset-id",
-        help="Use existing Fireworks dataset id (skip local materialization)",
+        "--dataset",
+        help="Use existing dataset (ID or resource 'accounts/{acct}/datasets/{id}') to skip local materialization",
     )
     rft_parser.add_argument(
         "--dataset-jsonl",
@@ -400,6 +400,8 @@ def parse_args(args=None):
     rft_parser.add_argument("--learning-rate", type=float, default=3e-5)
     rft_parser.add_argument("--max-context-length", type=int, default=65536)
     rft_parser.add_argument("--lora-rank", type=int, default=16)
+    rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps")
+    rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of LR warmup steps")
     rft_parser.add_argument("--accelerator-count", type=int, default=1)
     rft_parser.add_argument("--region", help="Fireworks region enum value")
     rft_parser.add_argument("--display-name", help="RFT job display name")
@@ -412,9 +414,14 @@ def parse_args(args=None):
     rft_parser.add_argument("--temperature", type=float)
     rft_parser.add_argument("--top-p", type=float)
     rft_parser.add_argument("--top-k", type=int)
-    rft_parser.add_argument("--max-tokens", type=int, default=32768)
-    rft_parser.add_argument("--n", type=int, default=8)
-    rft_parser.add_argument("--inference-extra-body", help="JSON string for extra inference params")
+    rft_parser.add_argument("--max-output-tokens", type=int, default=32768)
+    rft_parser.add_argument("--response-candidates-count", type=int, default=8)
+    rft_parser.add_argument("--extra-body", help="JSON string for extra inference params")
+    # MCP server (optional)
+    rft_parser.add_argument(
+        "--mcp-server",
+        help="The MCP server resource name to use for the reinforcement fine-tuning job.",
+    )
     # Wandb
     rft_parser.add_argument("--wandb-enabled", action="store_true")
     rft_parser.add_argument("--wandb-project")

eval_protocol/cli_commands/create_rft.py

Lines changed: 29 additions & 11 deletions
@@ -344,7 +344,8 @@ def _poll_evaluator_status(
 
 
 def create_rft_command(args) -> int:
-    evaluator_id: Optional[str] = getattr(args, "evaluator_id", None)
+    # Accept both new and legacy names for backwards-compat
+    evaluator_id: Optional[str] = getattr(args, "evaluator", None) or getattr(args, "evaluator_id", None)
     non_interactive: bool = bool(getattr(args, "yes", False))
     dry_run: bool = bool(getattr(args, "dry_run", False))
     force: bool = bool(getattr(args, "force", False))
@@ -410,8 +411,13 @@ def create_rft_command(args) -> int:
     selected_test_file_path, selected_test_func_name = _resolve_selected_test(
         project_root, evaluator_id, selected_tests=selected_tests
     )
-    # Resolve evaluator resource name to fully-qualified format required by API
-    evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
+    # Resolve evaluator resource name to fully-qualified format required by API.
+    # Allow users to pass either short id or fully-qualified resource.
+    if evaluator_id and evaluator_id.startswith("accounts/"):
+        evaluator_resource_name = evaluator_id
+        evaluator_id = _extract_terminal_segment(evaluator_id)
+    else:
+        evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
 
     # Optional short-circuit: if evaluator already exists and not forcing, skip upload path
     skip_upload = False
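_extract_terminal_segment is an existing helper in create_rft.py that this hunk calls but does not show; a minimal sketch of the behavior the call sites assume (the real implementation may differ):

    # Hypothetical sketch of the helper, not the committed code.
    def _extract_terminal_segment(resource_name: str) -> str:
        # 'accounts/acct/evaluators/my-eval' -> 'my-eval'
        return resource_name.rstrip("/").split("/")[-1]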
@@ -523,10 +529,16 @@ def create_rft_command(args) -> int:
             print(f"Warning: Failed to upload evaluator automatically: {e}")
 
     # Determine dataset id and materialization path
-    dataset_id = getattr(args, "dataset_id", None)
+    # Accept both new and legacy names for dataset
+    dataset_id = getattr(args, "dataset", None) or getattr(args, "dataset_id", None)
     dataset_jsonl = getattr(args, "dataset_jsonl", None)
     dataset_display_name = getattr(args, "dataset_display_name", None)
     dataset_builder = getattr(args, "dataset_builder", None)  # accepted but unused in simplified flow
+    dataset_resource_override: Optional[str] = None
+    if isinstance(dataset_id, str) and dataset_id.startswith("accounts/"):
+        # Caller passed a fully-qualified dataset; capture it for body and keep only terminal id for printing
+        dataset_resource_override = dataset_id
+        dataset_id = _extract_terminal_segment(dataset_id)
 
     if not dataset_id:
         # Prefer explicit --dataset-jsonl, else attempt to extract from the selected test's data loader or input_dataset.
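For example, passing --dataset accounts/acme/datasets/my-data (account name is a placeholder) would set dataset_resource_override to the full resource name while dataset_id is reduced to 'my-data' for log output.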
@@ -628,6 +640,8 @@ def create_rft_command(args) -> int:
         ("learningRate", "learning_rate"),
         ("maxContextLength", "max_context_length"),
         ("loraRank", "lora_rank"),
+        ("gradAccumulationSteps", "gradient_accumulation_steps"),
+        ("learningRateWarmupSteps", "learning_rate_warmup_steps"),
         ("acceleratorCount", "accelerator_count"),
         ("region", "region"),
     ]:
@@ -640,14 +654,14 @@ def create_rft_command(args) -> int:
         ("temperature", "temperature"),
         ("topP", "top_p"),
         ("topK", "top_k"),
-        ("maxTokens", "max_tokens"),
-        ("n", "n"),
+        ("maxTokens", "max_output_tokens"),
+        ("responseCandidatesCount", "response_candidates_count"),
     ]:
         val = getattr(args, arg_name, None)
         if val is not None:
             inference_params[key] = val
-    if getattr(args, "inference_extra_body", None):
-        inference_params["extraBody"] = args.inference_extra_body
+    if getattr(args, "extra_body", None):
+        inference_params["extraBody"] = args.extra_body
 
     wandb_config: Optional[Dict[str, Any]] = None
     if getattr(args, "wandb_enabled", False):
@@ -659,9 +673,12 @@ def create_rft_command(args) -> int:
             "runId": getattr(args, "wandb_run_id", None),
         }
 
+    # Build dataset resource (prefer override when provided)
+    dataset_resource = dataset_resource_override or f"accounts/{account_id}/datasets/{dataset_id}"
+
     body: Dict[str, Any] = {
-        # "displayName": getattr(args, "display_name", None) or f"{evaluator_id}-rft",
-        "dataset": f"accounts/{account_id}/datasets/{dataset_id}",
+        "displayName": getattr(args, "display_name", None),
+        "dataset": dataset_resource,
         "evaluator": evaluator_resource_name,
         "evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)),
         "trainingConfig": training_config,
@@ -670,7 +687,8 @@ def create_rft_command(args) -> int:
         "chunkSize": getattr(args, "chunk_size", None),
         "outputStats": None,
         "outputMetrics": None,
-        "mcpServer": None,
+        "mcpServer": getattr(args, "mcp_server", None),
+        "jobId": getattr(args, "rft_job_id", None),
     }
     # Debug: print minimal summary
    print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")

tests/test_cli_create_rft_infer.py

Lines changed: 115 additions & 23 deletions
@@ -65,8 +65,8 @@ def _fake_create_job(account_id, api_key, api_base, body):
 
     args = argparse.Namespace(
         # Evaluator and dataset
-        evaluator_id="my-evaluator",
-        dataset_id=None,
+        evaluator="my-evaluator",
+        dataset=None,
         dataset_jsonl=str(ds_path),
         dataset_display_name="My Dataset",
         dataset_builder=None,
@@ -91,9 +91,9 @@ def _fake_create_job(account_id, api_key, api_base, body):
         temperature=0.9,
         top_p=0.95,
         top_k=50,
-        max_tokens=4096,
-        n=6,
-        inference_extra_body='{"foo":"bar"}',
+        max_output_tokens=4096,
+        response_candidates_count=6,
+        extra_body='{"foo":"bar"}',
         # Rollout chunking and eval carveout
         chunk_size=250,
         eval_auto_carveout=False,  # explicitly disabled via --no-eval-auto-carveout
@@ -139,7 +139,7 @@ def _fake_create_job(account_id, api_key, api_base, body):
     assert abs(ip["topP"] - 0.95) < 1e-12
     assert ip["topK"] == 50
     assert ip["maxTokens"] == 4096
-    assert ip["n"] == 6
+    assert ip["responseCandidatesCount"] == 6
     assert ip["extraBody"] == '{"foo":"bar"}'
 
     # W&B mapping
@@ -195,12 +195,12 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
 
     # Build args: non_interactive (yes=True), no explicit evaluator_id, valid warm_start_from
     args = type("Args", (), {})()
-    setattr(args, "evaluator_id", None)
+    setattr(args, "evaluator", None)
     setattr(args, "yes", True)
     setattr(args, "dry_run", False)
     setattr(args, "force", False)
     setattr(args, "env_file", None)
-    setattr(args, "dataset_id", None)
+    setattr(args, "dataset", None)
     setattr(args, "dataset_jsonl", str(ds_path))
     setattr(args, "dataset_display_name", None)
     setattr(args, "dataset_builder", None)
@@ -283,12 +283,12 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     import argparse
 
     args = argparse.Namespace(
-        evaluator_id=cr._normalize_evaluator_id("foo_eval-test_bar_evaluation"),
+        evaluator=cr._normalize_evaluator_id("foo_eval-test_bar_evaluation"),
         yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=str(ds_path),
         dataset_display_name=None,
         dataset_builder=None,
@@ -371,12 +371,12 @@ def _fake_upload(ns):
     import argparse
 
     args = argparse.Namespace(
-        evaluator_id=None,
+        evaluator=None,
         yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=str(ds_path),
         dataset_display_name=None,
         dataset_builder=None,
@@ -438,12 +438,12 @@ def raise_for_status(self):
     import argparse
 
     args = argparse.Namespace(
-        evaluator_id="some-eval",
+        evaluator="some-eval",
         yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=str(ds_path),
         dataset_display_name=None,
         dataset_builder=None,
@@ -495,12 +495,12 @@ def _raise(*a, **k):
     import argparse
 
     args = argparse.Namespace(
-        evaluator_id="some-eval",
+        evaluator="some-eval",
         yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=str(project / "dataset.jsonl"),
         dataset_display_name=None,
         dataset_builder=None,
@@ -571,12 +571,12 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     import argparse
 
     args = argparse.Namespace(
-        evaluator_id=None,
+        evaluator=None,
         yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=None,
         dataset_display_name=None,
         dataset_builder=None,
@@ -648,12 +648,12 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     import argparse
 
     args = argparse.Namespace(
-        evaluator_id=None,
+        evaluator=None,
        yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=None,
         dataset_display_name=None,
         dataset_builder=None,
@@ -728,7 +728,7 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=None,
         dataset_display_name=None,
         dataset_builder=None,
@@ -815,12 +815,12 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     # Provide evaluator_id that matches beta.test_two
     eval_id = cr._normalize_evaluator_id("beta-test_two")
     args = argparse.Namespace(
-        evaluator_id=eval_id,
+        evaluator=eval_id,
         yes=True,
         dry_run=False,
         force=False,
         env_file=None,
-        dataset_id=None,
+        dataset=None,
         dataset_jsonl=None,
         dataset_display_name=None,
         dataset_builder=None,
@@ -844,3 +844,95 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
     # Ensure the dataset id is based on evaluator_id
     assert captured["dataset_id"].startswith(f"{eval_id}-dataset-")
     assert captured["jsonl_path"] == str(jsonl_path)
+
+
+def test_cli_full_command_style_evaluator_and_dataset_flags(monkeypatch):
+    # Env
+    monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "pyroworks-dev")
+    monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
+
+    # Mock evaluator exists and ACTIVE
+    class _Resp:
+        ok = True
+
+        def json(self):
+            return {"state": "ACTIVE"}
+
+        def raise_for_status(self):
+            return None
+
+    from eval_protocol.cli_commands import create_rft as cr
+
+    monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp())
+
+    # Capture body
+    captured = {"body": None}
+
+    def _fake_create_job(account_id, api_key, api_base, body):
+        captured["body"] = body
+        return {"name": f"accounts/{account_id}/reinforcementFineTuningJobs/xyz"}
+
+    monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", _fake_create_job)
+
+    # Build args via CLI parser to validate flag names
+    from eval_protocol.cli import parse_args
+
+    argv = [
+        "create",
+        "rft",
+        "--base-model",
+        "accounts/fireworks/models/qwen3-0p6b",
+        "--dataset",
+        "svgbench-small",
+        "--output-model",
+        "svgbench-agent-small-bchen-2",
+        "--evaluator",
+        "accounts/pyroworks-dev/evaluators/test-livesvgbench-test-svg-combined-evaluation1",
+        "--max-context-length",
+        "65536",
+        "--response-candidates-count",
+        "4",
+        "--batch-size",
+        "128000",
+        "--chunk-size",
+        "50",
+        "--epochs",
+        "4",
+        "--max-output-tokens",
+        "32768",
+        "--learning-rate",
+        "0.00003",
+        "--lora-rank",
+        "16",
+        "--yes",
+    ]
+    args, _ = parse_args(argv)
+
+    # Execute command
+    rc = cr.create_rft_command(args)
+    assert rc == 0
+    assert captured["body"] is not None
+    body = captured["body"]
+
+    # Evaluator and dataset resources
+    assert body["evaluator"] == "accounts/pyroworks-dev/evaluators/test-livesvgbench-test-svg-combined-evaluation1"
+    assert body["dataset"] == "accounts/pyroworks-dev/datasets/svgbench-small"
+
+    # Training config mapping
+    tc = body["trainingConfig"]
+    assert tc["baseModel"] == "accounts/fireworks/models/qwen3-0p6b"
+    assert tc["outputModel"] == "accounts/pyroworks-dev/models/svgbench-agent-small-bchen-2"
+    assert tc["epochs"] == 4
+    assert tc["batchSize"] == 128000
+    assert abs(tc["learningRate"] - 0.00003) < 1e-12
+    assert tc["loraRank"] == 16
+    assert tc["maxContextLength"] == 65536
+
+    # Inference params mapping
+    ip = body["inferenceParameters"]
+    assert ip["responseCandidatesCount"] == 4
+    assert ip["maxTokens"] == 32768
+
+    # Other top-level
+    assert body["chunkSize"] == 50
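The new flag-level test can be run on its own with, for example: pytest tests/test_cli_create_rft_infer.py -k full_command_style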
