Skip to content

Commit bd27ffd

Browse files
authored
unify flags (#328)
* unify flags * update * update test
1 parent 16648fe commit bd27ffd

File tree

4 files changed

+285
-50
lines changed

4 files changed

+285
-50
lines changed

eval_protocol/cli.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -371,13 +371,13 @@ def parse_args(args=None):
371371
help="Create a Reinforcement Fine-tuning Job on Fireworks",
372372
)
373373
rft_parser.add_argument(
374-
"--evaluator-id",
375-
help="Evaluator ID used during upload; if omitted, derive from local traces or a single discovered test",
374+
"--evaluator",
375+
help="Evaluator ID or fully-qualified resource (accounts/{acct}/evaluators/{id}); if omitted, derive from local tests",
376376
)
377377
# Dataset options
378378
rft_parser.add_argument(
379-
"--dataset-id",
380-
help="Use existing Fireworks dataset id (skip local materialization)",
379+
"--dataset",
380+
help="Use existing dataset (ID or resource 'accounts/{acct}/datasets/{id}') to skip local materialization",
381381
)
382382
rft_parser.add_argument(
383383
"--dataset-jsonl",
@@ -400,6 +400,8 @@ def parse_args(args=None):
400400
rft_parser.add_argument("--learning-rate", type=float, default=3e-5)
401401
rft_parser.add_argument("--max-context-length", type=int, default=65536)
402402
rft_parser.add_argument("--lora-rank", type=int, default=16)
403+
rft_parser.add_argument("--gradient-accumulation-steps", type=int, help="Number of gradient accumulation steps")
404+
rft_parser.add_argument("--learning-rate-warmup-steps", type=int, help="Number of LR warmup steps")
403405
rft_parser.add_argument("--accelerator-count", type=int, default=1)
404406
rft_parser.add_argument("--region", help="Fireworks region enum value")
405407
rft_parser.add_argument("--display-name", help="RFT job display name")
@@ -412,17 +414,22 @@ def parse_args(args=None):
412414
rft_parser.add_argument("--temperature", type=float)
413415
rft_parser.add_argument("--top-p", type=float)
414416
rft_parser.add_argument("--top-k", type=int)
415-
rft_parser.add_argument("--max-tokens", type=int, default=32768)
416-
rft_parser.add_argument("--n", type=int, default=8)
417-
rft_parser.add_argument("--inference-extra-body", help="JSON string for extra inference params")
417+
rft_parser.add_argument("--max-output-tokens", type=int, default=32768)
418+
rft_parser.add_argument("--response-candidates-count", type=int, default=8)
419+
rft_parser.add_argument("--extra-body", help="JSON string for extra inference params")
420+
# MCP server (optional)
421+
rft_parser.add_argument(
422+
"--mcp-server",
423+
help="The MCP server resource name to use for the reinforcement fine-tuning job.",
424+
)
418425
# Wandb
419426
rft_parser.add_argument("--wandb-enabled", action="store_true")
420427
rft_parser.add_argument("--wandb-project")
421428
rft_parser.add_argument("--wandb-entity")
422429
rft_parser.add_argument("--wandb-run-id")
423430
rft_parser.add_argument("--wandb-api-key")
424431
# Misc
425-
rft_parser.add_argument("--rft-job-id", help="Specify an explicit RFT job id")
432+
rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id")
426433
rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
427434
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending")
428435
rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID")

eval_protocol/cli_commands/create_rft.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def _poll_evaluator_status(
344344

345345

346346
def create_rft_command(args) -> int:
347-
evaluator_id: Optional[str] = getattr(args, "evaluator_id", None)
347+
evaluator_id: Optional[str] = getattr(args, "evaluator", None)
348348
non_interactive: bool = bool(getattr(args, "yes", False))
349349
dry_run: bool = bool(getattr(args, "dry_run", False))
350350
force: bool = bool(getattr(args, "force", False))
@@ -373,19 +373,19 @@ def create_rft_command(args) -> int:
373373
print("No evaluation tests found.")
374374
print("\nHint: Make sure your tests use the @evaluation_test decorator.")
375375
return 1
376-
# Always interactive selection here (no implicit quiet unless --evaluator-id was provided)
376+
# Always interactive selection here
377377
try:
378378
selected_tests = _prompt_select(tests, non_interactive=non_interactive)
379379
except Exception:
380-
print("Error: Failed to open selector UI. Please pass --evaluator-id or --entry explicitly.")
380+
print("Error: Failed to open selector UI. Please pass --evaluator or --entry explicitly.")
381381
return 1
382382
if not selected_tests:
383383
print("No tests selected.")
384384
return 1
385385
if len(selected_tests) != 1:
386386
if non_interactive and len(selected_tests) > 1:
387387
print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.")
388-
print(" Please pass --evaluator-id or --entry to disambiguate.")
388+
print(" Please pass --evaluator or --entry to disambiguate.")
389389
try:
390390
# Offer candidate evaluator ids for convenience
391391
tests = _discover_tests(project_root)
@@ -410,8 +410,13 @@ def create_rft_command(args) -> int:
410410
selected_test_file_path, selected_test_func_name = _resolve_selected_test(
411411
project_root, evaluator_id, selected_tests=selected_tests
412412
)
413-
# Resolve evaluator resource name to fully-qualified format required by API
414-
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
413+
# Resolve evaluator resource name to fully-qualified format required by API.
414+
# Allow users to pass either short id or fully-qualified resource.
415+
if evaluator_id and evaluator_id.startswith("accounts/"):
416+
evaluator_resource_name = evaluator_id
417+
evaluator_id = _extract_terminal_segment(evaluator_id)
418+
else:
419+
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
415420

416421
# Optional short-circuit: if evaluator already exists and not forcing, skip upload path
417422
skip_upload = False
@@ -470,10 +475,10 @@ def create_rft_command(args) -> int:
470475
# If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators
471476
if selected_entry is None and len(tests) > 1:
472477
print(
473-
f"Error: Multiple evaluation tests found, and the selected evaluator_id {evaluator_id} does not match any discovered test.\n"
474-
" Please re-run specifying the evaluator id.\n"
478+
f"Error: Multiple evaluation tests found, and the selected evaluator {evaluator_id} does not match any discovered test.\n"
479+
" Please re-run specifying the evaluator.\n"
475480
" Hints:\n"
476-
" - eval-protocol create rft --evaluator-id <existing-evaluator-id>\n"
481+
" - eval-protocol create rft --evaluator <existing-evaluator-id>\n"
477482
)
478483
return 1
479484

@@ -523,10 +528,15 @@ def create_rft_command(args) -> int:
523528
print(f"Warning: Failed to upload evaluator automatically: {e}")
524529

525530
# Determine dataset id and materialization path
526-
dataset_id = getattr(args, "dataset_id", None)
531+
dataset_id = getattr(args, "dataset", None)
527532
dataset_jsonl = getattr(args, "dataset_jsonl", None)
528533
dataset_display_name = getattr(args, "dataset_display_name", None)
529534
dataset_builder = getattr(args, "dataset_builder", None) # accepted but unused in simplified flow
535+
dataset_resource_override: Optional[str] = None
536+
if isinstance(dataset_id, str) and dataset_id.startswith("accounts/"):
537+
# Caller passed a fully-qualified dataset; capture it for body and keep only terminal id for printing
538+
dataset_resource_override = dataset_id
539+
dataset_id = _extract_terminal_segment(dataset_id)
530540

531541
if not dataset_id:
532542
# Prefer explicit --dataset-jsonl, else attempt to extract from the selected test's data loader or input_dataset.
@@ -573,7 +583,7 @@ def create_rft_command(args) -> int:
573583
print(f"Warning: dataset builder failed: {e}")
574584
if not dataset_jsonl:
575585
print(
576-
"Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test."
586+
"Error: Could not determine dataset. Provide --dataset or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test."
577587
)
578588
return 1
579589

@@ -628,6 +638,8 @@ def create_rft_command(args) -> int:
628638
("learningRate", "learning_rate"),
629639
("maxContextLength", "max_context_length"),
630640
("loraRank", "lora_rank"),
641+
("gradientAccumulationSteps", "gradient_accumulation_steps"),
642+
("learningRateWarmupSteps", "learning_rate_warmup_steps"),
631643
("acceleratorCount", "accelerator_count"),
632644
("region", "region"),
633645
]:
@@ -640,14 +652,25 @@ def create_rft_command(args) -> int:
640652
("temperature", "temperature"),
641653
("topP", "top_p"),
642654
("topK", "top_k"),
643-
("maxTokens", "max_tokens"),
644-
("n", "n"),
655+
("maxTokens", "max_output_tokens"),
656+
("n", "response_candidates_count"),
645657
]:
646658
val = getattr(args, arg_name, None)
647659
if val is not None:
648660
inference_params[key] = val
649-
if getattr(args, "inference_extra_body", None):
650-
inference_params["extraBody"] = args.inference_extra_body
661+
if getattr(args, "extra_body", None):
662+
extra = getattr(args, "extra_body")
663+
if isinstance(extra, (dict, list)):
664+
try:
665+
inference_params["extraBody"] = json.dumps(extra, ensure_ascii=False)
666+
except (TypeError, ValueError) as e:
667+
print(f"Error: --extra-body dict/list must be JSON-serializable: {e}")
668+
return 1
669+
elif isinstance(extra, str):
670+
inference_params["extraBody"] = extra
671+
else:
672+
print("Error: --extra-body must be a JSON string or a JSON-serializable dict/list.")
673+
return 1
651674

652675
wandb_config: Optional[Dict[str, Any]] = None
653676
if getattr(args, "wandb_enabled", False):
@@ -659,9 +682,12 @@ def create_rft_command(args) -> int:
659682
"runId": getattr(args, "wandb_run_id", None),
660683
}
661684

685+
# Build dataset resource (prefer override when provided)
686+
dataset_resource = dataset_resource_override or f"accounts/{account_id}/datasets/{dataset_id}"
687+
662688
body: Dict[str, Any] = {
663-
# "displayName": getattr(args, "display_name", None) or f"{evaluator_id}-rft",
664-
"dataset": f"accounts/{account_id}/datasets/{dataset_id}",
689+
"displayName": getattr(args, "display_name", None),
690+
"dataset": dataset_resource,
665691
"evaluator": evaluator_resource_name,
666692
"evalAutoCarveout": bool(getattr(args, "eval_auto_carveout", True)),
667693
"trainingConfig": training_config,
@@ -670,7 +696,8 @@ def create_rft_command(args) -> int:
670696
"chunkSize": getattr(args, "chunk_size", None),
671697
"outputStats": None,
672698
"outputMetrics": None,
673-
"mcpServer": None,
699+
"mcpServer": getattr(args, "mcp_server", None),
700+
"jobId": getattr(args, "job_id", None),
674701
}
675702
# Debug: print minimal summary
676703
print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")

eval_protocol/fireworks_rft.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import uuid
99
from pathlib import Path
1010
from typing import Any, Callable, Dict, Iterable, Optional, Tuple
11+
from urllib.parse import urlencode
1112

1213
import requests
1314

@@ -186,6 +187,14 @@ def create_reinforcement_fine_tuning_job(
186187
body: Dict[str, Any],
187188
) -> Dict[str, Any]:
188189
url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/reinforcementFineTuningJobs"
190+
# Move optional jobId from body to query parameter if provided
191+
job_id = body.get("jobId")
192+
if isinstance(job_id, str):
193+
job_id = job_id.strip()
194+
if job_id:
195+
# Remove from body and append as query param
196+
body.pop("jobId", None)
197+
url = f"{url}?{urlencode({'reinforcementFineTuningJobId': job_id})}"
189198
headers = {
190199
"Authorization": f"Bearer {api_key}",
191200
"Content-Type": "application/json",

0 commit comments

Comments
 (0)