Skip to content

Commit 63d1755

Browse files
committed
udpate
1 parent ae6cbc4 commit 63d1755

File tree

4 files changed

+59
-27
lines changed

4 files changed

+59
-27
lines changed

eval_protocol/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -429,7 +429,7 @@ def parse_args(args=None):
429429
rft_parser.add_argument("--wandb-run-id")
430430
rft_parser.add_argument("--wandb-api-key")
431431
# Misc
432-
rft_parser.add_argument("--rft-job-id", help="Specify an explicit RFT job id")
432+
rft_parser.add_argument("--job-id", help="Specify an explicit RFT job id")
433433
rft_parser.add_argument("--yes", "-y", action="store_true", help="Non-interactive mode")
434434
rft_parser.add_argument("--dry-run", action="store_true", help="Print planned REST calls without sending")
435435
rft_parser.add_argument("--force", action="store_true", help="Overwrite existing evaluator with the same ID")

eval_protocol/cli_commands/create_rft.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,7 @@ def _poll_evaluator_status(
344344

345345

346346
def create_rft_command(args) -> int:
347-
# Accept both new and legacy names for backwards-compat
348-
evaluator_id: Optional[str] = getattr(args, "evaluator", None) or getattr(args, "evaluator_id", None)
347+
evaluator_id: Optional[str] = getattr(args, "evaluator", None)
349348
non_interactive: bool = bool(getattr(args, "yes", False))
350349
dry_run: bool = bool(getattr(args, "dry_run", False))
351350
force: bool = bool(getattr(args, "force", False))
@@ -374,19 +373,19 @@ def create_rft_command(args) -> int:
374373
print("No evaluation tests found.")
375374
print("\nHint: Make sure your tests use the @evaluation_test decorator.")
376375
return 1
377-
# Always interactive selection here (no implicit quiet unless --evaluator-id was provided)
376+
# Always interactive selection here
378377
try:
379378
selected_tests = _prompt_select(tests, non_interactive=non_interactive)
380379
except Exception:
381-
print("Error: Failed to open selector UI. Please pass --evaluator-id or --entry explicitly.")
380+
print("Error: Failed to open selector UI. Please pass --evaluator or --entry explicitly.")
382381
return 1
383382
if not selected_tests:
384383
print("No tests selected.")
385384
return 1
386385
if len(selected_tests) != 1:
387386
if non_interactive and len(selected_tests) > 1:
388387
print("Error: Multiple evaluation tests found in --yes (non-interactive) mode.")
389-
print(" Please pass --evaluator-id or --entry to disambiguate.")
388+
print(" Please pass --evaluator or --entry to disambiguate.")
390389
try:
391390
# Offer candidate evaluator ids for convenience
392391
tests = _discover_tests(project_root)
@@ -476,10 +475,10 @@ def create_rft_command(args) -> int:
476475
# If still unresolved and multiple tests exist, fail fast to avoid uploading unintended evaluators
477476
if selected_entry is None and len(tests) > 1:
478477
print(
479-
f"Error: Multiple evaluation tests found, and the selected evaluator_id {evaluator_id} does not match any discovered test.\n"
480-
" Please re-run specifying the evaluator id.\n"
478+
f"Error: Multiple evaluation tests found, and the selected evaluator {evaluator_id} does not match any discovered test.\n"
479+
" Please re-run specifying the evaluator.\n"
481480
" Hints:\n"
482-
" - eval-protocol create rft --evaluator-id <existing-evaluator-id>\n"
481+
" - eval-protocol create rft --evaluator <existing-evaluator-id>\n"
483482
)
484483
return 1
485484

@@ -529,8 +528,7 @@ def create_rft_command(args) -> int:
529528
print(f"Warning: Failed to upload evaluator automatically: {e}")
530529

531530
# Determine dataset id and materialization path
532-
# Accept both new and legacy names for dataset
533-
dataset_id = getattr(args, "dataset", None) or getattr(args, "dataset_id", None)
531+
dataset_id = getattr(args, "dataset", None)
534532
dataset_jsonl = getattr(args, "dataset_jsonl", None)
535533
dataset_display_name = getattr(args, "dataset_display_name", None)
536534
dataset_builder = getattr(args, "dataset_builder", None) # accepted but unused in simplified flow
@@ -585,7 +583,7 @@ def create_rft_command(args) -> int:
585583
print(f"Warning: dataset builder failed: {e}")
586584
if not dataset_jsonl:
587585
print(
588-
"Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test."
586+
"Error: Could not determine dataset. Provide --dataset or --dataset-jsonl, or ensure a JSONL-based data loader or input_dataset is used in your single discovered test."
589587
)
590588
return 1
591589

@@ -640,7 +638,7 @@ def create_rft_command(args) -> int:
640638
("learningRate", "learning_rate"),
641639
("maxContextLength", "max_context_length"),
642640
("loraRank", "lora_rank"),
643-
("gradAccumulationSteps", "gradient_accumulation_steps"),
641+
("gradientAccumulationSteps", "gradient_accumulation_steps"),
644642
("learningRateWarmupSteps", "learning_rate_warmup_steps"),
645643
("acceleratorCount", "accelerator_count"),
646644
("region", "region"),
@@ -655,13 +653,24 @@ def create_rft_command(args) -> int:
655653
("topP", "top_p"),
656654
("topK", "top_k"),
657655
("maxTokens", "max_output_tokens"),
658-
("responseCandidatesCount", "response_candidates_count"),
656+
("n", "response_candidates_count"),
659657
]:
660658
val = getattr(args, arg_name, None)
661659
if val is not None:
662660
inference_params[key] = val
663661
if getattr(args, "extra_body", None):
664-
inference_params["extraBody"] = args.extra_body
662+
extra = getattr(args, "extra_body")
663+
if isinstance(extra, (dict, list)):
664+
try:
665+
inference_params["extraBody"] = json.dumps(extra, ensure_ascii=False)
666+
except (TypeError, ValueError) as e:
667+
print(f"Error: --extra-body dict/list must be JSON-serializable: {e}")
668+
return 1
669+
elif isinstance(extra, str):
670+
inference_params["extraBody"] = extra
671+
else:
672+
print("Error: --extra-body must be a JSON string or a JSON-serializable dict/list.")
673+
return 1
665674

666675
wandb_config: Optional[Dict[str, Any]] = None
667676
if getattr(args, "wandb_enabled", False):
@@ -688,7 +697,7 @@ def create_rft_command(args) -> int:
688697
"outputStats": None,
689698
"outputMetrics": None,
690699
"mcpServer": getattr(args, "mcp_server", None),
691-
"jobId": getattr(args, "rft_job_id", None),
700+
"jobId": getattr(args, "job_id", None),
692701
}
693702
# Debug: print minimal summary
694703
print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")

eval_protocol/fireworks_rft.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import uuid
99
from pathlib import Path
1010
from typing import Any, Callable, Dict, Iterable, Optional, Tuple
11+
from urllib.parse import urlencode
1112

1213
import requests
1314

@@ -186,6 +187,14 @@ def create_reinforcement_fine_tuning_job(
186187
body: Dict[str, Any],
187188
) -> Dict[str, Any]:
188189
url = f"{api_base.rstrip('/')}/v1/accounts/{account_id}/reinforcementFineTuningJobs"
190+
# Move optional jobId from body to query parameter if provided
191+
job_id = body.get("jobId")
192+
if isinstance(job_id, str):
193+
job_id = job_id.strip()
194+
if job_id:
195+
# Remove from body and append as query param
196+
body.pop("jobId", None)
197+
url = f"{url}?{urlencode({'reinforcementFineTuningJobId': job_id})}"
189198
headers = {
190199
"Authorization": f"Bearer {api_key}",
191200
"Content-Type": "application/json",

tests/test_cli_create_rft_infer.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _fake_create_job(account_id, api_key, api_base, body):
105105
wandb_run_id="run123",
106106
wandb_api_key="key123",
107107
# Unused in body but accepted by parser
108-
rft_job_id=None,
108+
job_id=None,
109109
display_name=None,
110110
)
111111

@@ -139,7 +139,7 @@ def _fake_create_job(account_id, api_key, api_base, body):
139139
assert abs(ip["topP"] - 0.95) < 1e-12
140140
assert ip["topK"] == 50
141141
assert ip["maxTokens"] == 4096
142-
assert ip["responseCandidatesCount"] == 6
142+
assert ip["n"] == 6
143143
assert ip["extraBody"] == '{"foo":"bar"}'
144144

145145
# W&B mapping
@@ -866,14 +866,23 @@ def raise_for_status(self):
866866

867867
monkeypatch.setattr(cr.requests, "get", lambda *a, **k: _Resp())
868868

869-
# Capture body
870-
captured = {"body": None}
869+
# Capture URL and JSON via fireworks layer
870+
import eval_protocol.fireworks_rft as fr
871871

872-
def _fake_create_job(account_id, api_key, api_base, body):
873-
captured["body"] = body
874-
return {"name": f"accounts/{account_id}/reinforcementFineTuningJobs/xyz"}
872+
captured = {"url": None, "json": None}
875873

876-
monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", _fake_create_job)
874+
class _RespPost:
875+
status_code = 200
876+
877+
def json(self):
878+
return {"name": "accounts/pyroworks-dev/reinforcementFineTuningJobs/xyz"}
879+
880+
def _fake_post(url, json=None, headers=None, timeout=None):
881+
captured["url"] = url
882+
captured["json"] = json
883+
return _RespPost()
884+
885+
monkeypatch.setattr(fr.requests, "post", _fake_post)
877886

878887
# Build args via CLI parser to validate flag names
879888
from eval_protocol.cli import parse_args
@@ -905,15 +914,17 @@ def _fake_create_job(account_id, api_key, api_base, body):
905914
"0.00003",
906915
"--lora-rank",
907916
"16",
917+
"--job-id",
918+
"custom-job-123",
908919
"--yes",
909920
]
910921
args, _ = parse_args(argv)
911922

912923
# Execute command
913924
rc = cr.create_rft_command(args)
914925
assert rc == 0
915-
assert captured["body"] is not None
916-
body = captured["body"]
926+
assert captured["json"] is not None
927+
body = captured["json"]
917928

918929
# Evaluator and dataset resources
919930
assert body["evaluator"] == "accounts/pyroworks-dev/evaluators/test-livesvgbench-test-svg-combined-evaluation1"
@@ -931,8 +942,11 @@ def _fake_create_job(account_id, api_key, api_base, body):
931942

932943
# Inference params mapping
933944
ip = body["inferenceParameters"]
934-
assert ip["responseCandidatesCount"] == 4
945+
assert ip["n"] == 4
935946
assert ip["maxTokens"] == 32768
936947

937948
# Other top-level
938949
assert body["chunkSize"] == 50
950+
# Job id sent as query param
951+
assert captured["url"] is not None and "reinforcementFineTuningJobId=custom-job-123" in captured["url"]
952+
assert "jobId" not in body

0 commit comments

Comments
 (0)