Skip to content

Commit 8ab1c92

Browse files
committed
update test
1 parent 63d1755 commit 8ab1c92

File tree

1 file changed

+87
-1
lines changed

1 file changed

+87
-1
lines changed

tests/test_cli_create_rft_infer.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,7 +723,7 @@ def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, d
723723
import argparse
724724

725725
args = argparse.Namespace(
726-
evaluator_id=None,
726+
evaluator=None,
727727
yes=True,
728728
dry_run=False,
729729
force=False,
@@ -950,3 +950,89 @@ def _fake_post(url, json=None, headers=None, timeout=None):
950950
# Job id sent as query param
951951
assert captured["url"] is not None and "reinforcementFineTuningJobId=custom-job-123" in captured["url"]
952952
assert "jobId" not in body
953+
954+
955+
def test_create_rft_prefers_explicit_dataset_jsonl_over_input_dataset(tmp_path, monkeypatch):
956+
# Setup project
957+
project = tmp_path / "proj"
958+
project.mkdir()
959+
monkeypatch.chdir(project)
960+
961+
# Environment
962+
monkeypatch.setenv("FIREWORKS_API_KEY", "fw_dummy")
963+
monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "acct123")
964+
monkeypatch.setenv("FIREWORKS_API_BASE", "https://api.fireworks.ai")
965+
966+
# Single discovered test
967+
test_file = project / "metric" / "test_pref.py"
968+
test_file.parent.mkdir(parents=True, exist_ok=True)
969+
test_file.write_text("# prefer explicit dataset_jsonl", encoding="utf-8")
970+
single_disc = SimpleNamespace(qualname="metric.test_pref", file_path=str(test_file))
971+
monkeypatch.setattr(cr, "_discover_tests", lambda cwd: [single_disc])
972+
973+
# Stub selector, upload, and polling
974+
import eval_protocol.cli_commands.upload as upload_mod
975+
976+
monkeypatch.setattr(upload_mod, "_prompt_select", lambda tests, non_interactive=False: tests[:1])
977+
monkeypatch.setattr(upload_mod, "upload_command", lambda args: 0)
978+
monkeypatch.setattr(cr, "_poll_evaluator_status", lambda **kwargs: True)
979+
980+
# Prepare two JSONL paths: one explicit via --dataset-jsonl and one inferable via input_dataset
981+
explicit_jsonl = project / "metric" / "explicit.jsonl"
982+
explicit_jsonl.write_text('{"row":"explicit"}\n', encoding="utf-8")
983+
inferred_jsonl = project / "metric" / "inferred.jsonl"
984+
inferred_jsonl.write_text('{"row":"inferred"}\n', encoding="utf-8")
985+
986+
# If inference were to happen, return inferred path — but explicit should win
987+
monkeypatch.setattr(cr, "_extract_jsonl_from_dataloader", lambda f, fn: None)
988+
calls = {"input_dataset": 0}
989+
990+
def _extract_input_dataset(file_path, func_name):
991+
calls["input_dataset"] += 1
992+
return str(inferred_jsonl)
993+
994+
monkeypatch.setattr(cr, "_extract_jsonl_from_input_dataset", _extract_input_dataset)
995+
monkeypatch.setattr(cr, "detect_dataset_builder", lambda metric_dir: None)
996+
997+
captured = {"jsonl_path": None}
998+
999+
def _fake_create_dataset_from_jsonl(account_id, api_key, api_base, dataset_id, display_name, jsonl_path):
1000+
captured["jsonl_path"] = jsonl_path
1001+
return dataset_id, {"name": f"accounts/{account_id}/datasets/{dataset_id}", "state": "UPLOADING"}
1002+
1003+
monkeypatch.setattr(cr, "create_dataset_from_jsonl", _fake_create_dataset_from_jsonl)
1004+
monkeypatch.setattr(cr, "create_reinforcement_fine_tuning_job", lambda *a, **k: {"name": "jobs/123"})
1005+
1006+
import argparse
1007+
1008+
args = argparse.Namespace(
1009+
evaluator=None,
1010+
yes=True,
1011+
dry_run=False,
1012+
force=False,
1013+
env_file=None,
1014+
dataset=None,
1015+
dataset_jsonl=str(explicit_jsonl),
1016+
dataset_display_name=None,
1017+
dataset_builder=None,
1018+
base_model=None,
1019+
warm_start_from="accounts/acct123/models/ft-abc123",
1020+
output_model=None,
1021+
n=None,
1022+
max_tokens=None,
1023+
learning_rate=None,
1024+
batch_size=None,
1025+
epochs=None,
1026+
lora_rank=None,
1027+
max_context_length=None,
1028+
chunk_size=None,
1029+
eval_auto_carveout=None,
1030+
)
1031+
1032+
rc = cr.create_rft_command(args)
1033+
assert rc == 0
1034+
# Ensure the explicitly provided JSONL file is used, not the inferred one
1035+
assert captured["jsonl_path"] == str(explicit_jsonl)
1036+
assert captured["jsonl_path"] != str(inferred_jsonl)
1037+
# And because --dataset-jsonl was provided, we should never call the input_dataset extractor
1038+
assert calls["input_dataset"] == 0

0 commit comments

Comments
 (0)