Skip to content

Commit 32c7916

Browse files
committed
single command experience
1 parent 9f352ed commit 32c7916

File tree

5 files changed

+283
-57
lines changed

5 files changed

+283
-57
lines changed

eval_protocol/cli_commands/create_rft.py

Lines changed: 208 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import json
22
import os
33
import sys
4+
import time
5+
import argparse
46
from typing import Any, Dict, Optional
57

68
from ..auth import (
@@ -11,13 +13,8 @@
1113
)
1214
from ..fireworks_rft import (
1315
_map_api_host_to_app_host,
14-
build_default_dataset_id,
15-
build_default_output_model,
1616
create_dataset_from_jsonl,
1717
create_reinforcement_fine_tuning_job,
18-
detect_dataset_builder,
19-
load_evaluator_trace,
20-
materialize_dataset_via_builder,
2118
)
2219
from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source
2320

@@ -58,6 +55,129 @@ def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) ->
5855
pass
5956

6057

58+
def _auto_find_jsonl(cwd: str) -> Optional[str]:
59+
"""Find a reasonable JSONL dataset file in the current project.
60+
61+
Priority order:
62+
- dataset.jsonl in cwd
63+
- data/dataset.jsonl
64+
- first *.jsonl under cwd (depth-first, skipping common vendor/venv/build dirs)
65+
Returns a RELATIVE path from cwd if possible.
66+
"""
67+
# Direct candidates
68+
direct_candidates = [
69+
os.path.join(cwd, "dataset.jsonl"),
70+
os.path.join(cwd, "data", "dataset.jsonl"),
71+
]
72+
for p in direct_candidates:
73+
if os.path.isfile(p):
74+
try:
75+
return os.path.relpath(p, cwd)
76+
except Exception:
77+
return p
78+
79+
# Walk and find any .jsonl
80+
skip_dirs = {".venv", "venv", "node_modules", "dist", "build", "__pycache__", ".git", "vendor"}
81+
for dirpath, dirnames, filenames in os.walk(cwd):
82+
# prune
83+
dirnames[:] = [d for d in dirnames if d not in skip_dirs and not d.startswith(".")]
84+
for name in sorted(filenames):
85+
if name.endswith(".jsonl"):
86+
candidate = os.path.join(dirpath, name)
87+
try:
88+
return os.path.relpath(candidate, cwd)
89+
except Exception:
90+
return candidate
91+
return None
92+
93+
94+
def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> Optional[str]:
95+
"""Import the test module and extract a JSONL path from data_loaders param if present.
96+
97+
Looks for a pytest.mark.parametrize with argnames containing 'data_loaders' and attempts to
98+
find an object with attribute 'jsonl_path'. If a relative path is found, it is resolved
99+
relative to the directory of the test file.
100+
"""
101+
try:
102+
import importlib.util
103+
from pathlib import Path
104+
105+
spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path)
106+
if not spec or not spec.loader:
107+
return None
108+
module = importlib.util.module_from_spec(spec)
109+
sys.modules[spec.name] = module
110+
spec.loader.exec_module(module) # type: ignore[attr-defined]
111+
if not hasattr(module, test_func_name):
112+
return None
113+
wrapper = getattr(module, test_func_name)
114+
marks = getattr(wrapper, "pytestmark", [])
115+
for m in marks:
116+
if getattr(m, "name", "") == "parametrize":
117+
kwargs = getattr(m, "kwargs", {})
118+
argnames = kwargs.get("argnames", (m.args[0] if m.args else []))
119+
argvalues = kwargs.get("argvalues", (m.args[1] if len(m.args) > 1 else []))
120+
# Normalize argnames to list
121+
if isinstance(argnames, str):
122+
names_list = [n.strip() for n in argnames.split(",") if n.strip()]
123+
else:
124+
names_list = list(argnames)
125+
if "data_loaders" not in names_list:
126+
continue
127+
idx = names_list.index("data_loaders")
128+
# argvalues is a list of tuples/values aligned with argnames
129+
for val in argvalues:
130+
# Normalize to tuple
131+
if not isinstance(val, (tuple, list)):
132+
params = (val,)
133+
else:
134+
params = tuple(val)
135+
if idx >= len(params):
136+
continue
137+
dataloaders_obj = params[idx]
138+
# May be a list or single loader
139+
candidates = (
140+
list(dataloaders_obj) if isinstance(dataloaders_obj, (list, tuple)) else [dataloaders_obj]
141+
)
142+
for dl in candidates:
143+
jsonl_path = getattr(dl, "jsonl_path", None)
144+
if isinstance(jsonl_path, str) and jsonl_path:
145+
if os.path.isabs(jsonl_path):
146+
return jsonl_path
147+
base_dir = os.path.dirname(os.path.abspath(test_file_path))
148+
return os.path.abspath(os.path.join(base_dir, jsonl_path))
149+
return None
150+
except Exception:
151+
return None
152+
153+
154+
def _build_trimmed_dataset_id(evaluator_id: str) -> str:
    """Build a dataset id derived from evaluator_id, trimmed to 63 chars.

    Format: <normalized-base>-dataset-YYYYMMDDHHMMSS, where the normalized
    base is shortened so the whole id fits within 63 characters.
    """
    # Local import mirrors the evaluator-id normalization rules and avoids a
    # cycle at module import time.
    from .upload import _normalize_evaluator_id

    timestamp_suffix = f"-dataset-{time.strftime('%Y%m%d%H%M%S')}"
    # Characters left over for the base once the suffix is accounted for.
    budget = max(63 - len(timestamp_suffix), 1)

    base = _normalize_evaluator_id(evaluator_id)
    if len(base) > budget:
        base = base[:budget].rstrip("-")
    if not base:
        base = "dataset"
    # Dataset ids must begin with a letter; prefix and re-trim if needed.
    if not base[0].isalpha():
        base = f"eval-{base}"
        if len(base) > budget:
            base = base[:budget]
    base = base.rstrip("-") or "dataset"
    return f"{base}{timestamp_suffix}"
179+
180+
61181
def _auto_select_evaluator_id(cwd: str) -> Optional[str]:
62182
# Try local traces
63183
traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators")
@@ -101,75 +221,117 @@ def create_rft_command(args) -> int:
101221
print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
102222
return 1
103223

104-
# Resolve evaluator resource name via local trace
105-
# trace = load_evaluator_trace(project_root, evaluator_id)
106-
# if not trace or not isinstance(trace, dict):
107-
# print(
108-
# "Error: Evaluator trace not found. Run 'eval-protocol upload' first or provide --dataset-id/--dataset-jsonl and --evaluator-id."
109-
# )
110-
# return 1
111-
# evaluator_resource_name = trace.get("evaluator_resource_name") or trace.get("name") or evaluator_id
112-
evaluator_resource_name = evaluator_id
224+
# Resolve evaluator resource name to fully-qualified format required by API
225+
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
226+
227+
# Ensure evaluator exists by invoking the upload flow programmatically
228+
try:
229+
from .upload import upload_command
230+
231+
tests = _discover_tests(project_root)
232+
selected_entry: Optional[str] = None
233+
if len(tests) == 1:
234+
func_name = tests[0].qualname.split(".")[-1]
235+
abs_path = os.path.abspath(tests[0].file_path)
236+
try:
237+
rel = os.path.relpath(abs_path, project_root)
238+
except Exception:
239+
rel = abs_path
240+
selected_entry = f"{rel}::{func_name}"
241+
else:
242+
# Try to match evaluator_id to a discovered test's normalized ID
243+
for t in tests:
244+
func_name = t.qualname.split(".")[-1]
245+
source_file_name = os.path.splitext(os.path.basename(t.file_path))[0]
246+
candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
247+
if candidate == evaluator_id:
248+
abs_path = os.path.abspath(t.file_path)
249+
try:
250+
rel = os.path.relpath(abs_path, project_root)
251+
except Exception:
252+
rel = abs_path
253+
selected_entry = f"{rel}::{func_name}"
254+
break
255+
256+
upload_args = argparse.Namespace(
257+
path=project_root,
258+
entry=selected_entry,
259+
id=evaluator_id,
260+
display_name=None,
261+
description=None,
262+
force=False,
263+
yes=True,
264+
)
265+
rc = upload_command(upload_args)
266+
if rc == 0:
267+
print(f"✓ Uploaded/ensured evaluator: {evaluator_id}")
268+
else:
269+
print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.")
270+
except Exception as e:
271+
print(f"Warning: Failed to upload evaluator automatically: {e}")
113272

114273
# Determine dataset id and materialization path
115274
dataset_id = getattr(args, "dataset_id", None)
116275
dataset_jsonl = getattr(args, "dataset_jsonl", None)
117276
dataset_display_name = getattr(args, "dataset_display_name", None)
118-
dataset_builder = getattr(args, "dataset_builder", None)
277+
dataset_builder = getattr(args, "dataset_builder", None) # accepted but unused in simplified flow
119278

120279
if not dataset_id:
121-
# Try builder from args, else from trace detection
122-
# TODO: build dataset from traces directly
123-
# builder_spec = dataset_builder or trace.get("dataset_builder")
124-
# if not builder_spec:
125-
# # Attempt detect from metric_dir
126-
# metric_dir = trace.get("metric_dir")
127-
# if metric_dir:
128-
# builder_spec = detect_dataset_builder(metric_dir)
129-
# if not builder_spec:
130-
# print(
131-
# "Error: Could not determine dataset. Provide --dataset-id, --dataset-jsonl, or --dataset-builder."
132-
# )
133-
# return 1
134-
# try:
135-
# dataset_jsonl, count = materialize_dataset_via_builder(builder_spec)
136-
# print(f"✓ Materialized dataset via builder ({builder_spec}): {count} rows → {dataset_jsonl}")
137-
# except Exception as e:
138-
# print(f"Error: dataset builder failed: {e}")
139-
# return 1
140-
280+
# Prefer explicit --dataset-jsonl, else attempt to extract from data loader of the single discovered test
141281
if not dataset_jsonl:
142-
print("Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl.")
282+
tests = _discover_tests(project_root)
283+
if len(tests) == 1:
284+
func_name = tests[0].qualname.split(".")[-1]
285+
dataset_jsonl = _extract_jsonl_from_dataloader(tests[0].file_path, func_name)
286+
if dataset_jsonl:
287+
# Display relative path for readability
288+
try:
289+
rel = os.path.relpath(dataset_jsonl, project_root)
290+
except Exception:
291+
rel = dataset_jsonl
292+
print(f"✓ Using JSONL from data loader: {rel}")
293+
if not dataset_jsonl:
294+
print(
295+
"Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader is used in your single discovered test."
296+
)
143297
return 1
144298

145-
inferred_dataset_id = build_default_dataset_id(evaluator_id)
299+
inferred_dataset_id = _build_trimmed_dataset_id(evaluator_id)
146300
if dry_run:
147301
print("--dry-run: would create dataset and upload JSONL")
148302
dataset_id = inferred_dataset_id
149303
else:
150304
try:
305+
# Resolve dataset_jsonl path relative to CWD if needed
306+
jsonl_path_for_upload = (
307+
dataset_jsonl
308+
if os.path.isabs(dataset_jsonl)
309+
else os.path.abspath(os.path.join(project_root, dataset_jsonl))
310+
)
151311
dataset_id, _ = create_dataset_from_jsonl(
152312
account_id=account_id,
153313
api_key=api_key,
154314
api_base=api_base,
155315
dataset_id=inferred_dataset_id,
156316
display_name=dataset_display_name or inferred_dataset_id,
157-
jsonl_path=dataset_jsonl,
317+
jsonl_path=jsonl_path_for_upload,
158318
)
159319
print(f"✓ Created and uploaded dataset: {dataset_id}")
160320
except Exception as e:
161321
print(f"Error creating/uploading dataset: {e}")
162322
return 1
163323

164324
# Build training config/body
165-
training_config: Dict[str, Any] = {}
166-
if getattr(args, "base_model", None):
167-
training_config["baseModel"] = args.base_model
325+
# Ensure base model is explicitly provided for clarity
326+
if not getattr(args, "base_model", None):
327+
print(
328+
"Error: --base-model is required. Please specify the base model resource id (e.g., accounts/{account}/models/<model_id>)."
329+
)
330+
return 1
331+
332+
training_config: Dict[str, Any] = {"baseModel": args.base_model}
168333
if getattr(args, "warm_start_from", None):
169334
training_config["warmStartFrom"] = args.warm_start_from
170-
if "baseModel" not in training_config and "warmStartFrom" not in training_config:
171-
# Provide a conservative default if neither is set
172-
training_config["baseModel"] = "accounts/fireworks/models/llama-v3p1-8b-instruct"
173335

174336
# Optional hyperparameters
175337
for key, arg_name in [
@@ -221,14 +383,12 @@ def create_rft_command(args) -> int:
221383
"outputMetrics": None,
222384
"mcpServer": None,
223385
}
224-
print("Show body:")
225-
print(json.dumps(body, indent=2))
386+
# Debug: print minimal summary
387+
print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
226388
if getattr(args, "evaluation_dataset", None):
227389
body["evaluationDataset"] = args.evaluation_dataset
228390
if getattr(args, "output_model", None):
229391
body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{args.output_model}"
230-
else:
231-
body.setdefault("trainingConfig", {})["outputModel"] = build_default_output_model(evaluator_id)
232392

233393
# Clean None fields to avoid noisy payloads
234394
body = {k: v for k, v in body.items() if v is not None}
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from .dynamic_data_loader import DynamicDataLoader
22
from .inline_data_loader import InlineDataLoader
3+
from .jsonl_data_loader import EvaluationRowJsonlDataLoader
34

4-
__all__ = ["DynamicDataLoader", "InlineDataLoader"]
5+
__all__ = ["DynamicDataLoader", "InlineDataLoader", "EvaluationRowJsonlDataLoader"]
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
from __future__ import annotations
2+
3+
import os
4+
from dataclasses import dataclass
5+
from collections.abc import Sequence
6+
7+
from eval_protocol.common_utils import load_jsonl
8+
from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
9+
from eval_protocol.data_loader.models import (
10+
DataLoaderResult,
11+
DataLoaderVariant,
12+
EvaluationDataLoader,
13+
)
14+
15+
16+
@dataclass(kw_only=True)
class EvaluationRowJsonlDataLoader(EvaluationDataLoader):
    """Data loader that reads EvaluationRows from a JSONL file path.

    Each line of the JSONL file should be a serialized EvaluationRow dict;
    rows are converted to EvaluationRow objects via the default dataset
    adapter.
    """

    # Path to the JSONL file; a relative path is resolved against the
    # current working directory at load time.
    jsonl_path: str
    id: str = "jsonl"
    description: str | None = None

    def variants(self) -> Sequence[DataLoaderVariant]:
        def _load_variant() -> DataLoaderResult:
            resolved = self.jsonl_path if os.path.isabs(self.jsonl_path) else os.path.abspath(self.jsonl_path)
            raw_rows = load_jsonl(resolved)
            return DataLoaderResult(
                rows=default_dataset_adapter(raw_rows),
                type=type(self).__name__,
                variant_id=self.id,
                variant_description=self.description,
            )

        # A single lazy variant: nothing is read until the callable is invoked.
        return [_load_variant]

0 commit comments

Comments
 (0)