Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
256 changes: 208 additions & 48 deletions eval_protocol/cli_commands/create_rft.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import json
import os
import sys
import time
import argparse
from typing import Any, Dict, Optional

from ..auth import (
Expand All @@ -11,13 +13,8 @@
)
from ..fireworks_rft import (
_map_api_host_to_app_host,
build_default_dataset_id,
build_default_output_model,
create_dataset_from_jsonl,
create_reinforcement_fine_tuning_job,
detect_dataset_builder,
load_evaluator_trace,
materialize_dataset_via_builder,
)
from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source

Expand Down Expand Up @@ -58,6 +55,129 @@ def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) ->
pass


def _auto_find_jsonl(cwd: str) -> Optional[str]:
"""Find a reasonable JSONL dataset file in the current project.

Priority order:
- dataset.jsonl in cwd
- data/dataset.jsonl
- first *.jsonl under cwd (depth-first, skipping common vendor/venv/build dirs)
Returns a RELATIVE path from cwd if possible.
"""
# Direct candidates
direct_candidates = [
os.path.join(cwd, "dataset.jsonl"),
os.path.join(cwd, "data", "dataset.jsonl"),
]
for p in direct_candidates:
if os.path.isfile(p):
try:
return os.path.relpath(p, cwd)
except Exception:
return p

# Walk and find any .jsonl
skip_dirs = {".venv", "venv", "node_modules", "dist", "build", "__pycache__", ".git", "vendor"}
for dirpath, dirnames, filenames in os.walk(cwd):
# prune
dirnames[:] = [d for d in dirnames if d not in skip_dirs and not d.startswith(".")]
for name in sorted(filenames):
if name.endswith(".jsonl"):
candidate = os.path.join(dirpath, name)
try:
return os.path.relpath(candidate, cwd)
except Exception:
return candidate
return None


def _extract_jsonl_from_dataloader(test_file_path: str, test_func_name: str) -> Optional[str]:
    """Import the test module and extract a JSONL path from data_loaders param if present.

    Looks for a pytest.mark.parametrize with argnames containing 'data_loaders' and attempts to
    find an object with attribute 'jsonl_path'. If a relative path is found, it is resolved
    relative to the directory of the test file.

    Args:
        test_file_path: Path to the test module to import and inspect.
        test_func_name: Name of the test function whose parametrize marks are scanned.

    Returns:
        Absolute path to the JSONL referenced by the first data loader exposing a
        non-empty string ``jsonl_path`` attribute, or None if the module cannot be
        imported, the function or mark is absent, or no such loader is found.
    """
    try:
        import importlib.util
        from pathlib import Path

        # Import the test file as a standalone module keyed by its stem.
        # NOTE(review): exec_module runs the file's module-level code, so any
        # import-time side effects of the test module will fire here.
        spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path)
        if not spec or not spec.loader:
            return None
        module = importlib.util.module_from_spec(spec)
        # Register before exec so self-referential imports inside the module resolve.
        sys.modules[spec.name] = module
        spec.loader.exec_module(module)  # type: ignore[attr-defined]
        if not hasattr(module, test_func_name):
            return None
        wrapper = getattr(module, test_func_name)
        # pytest attaches applied marks to the function's `pytestmark` attribute.
        marks = getattr(wrapper, "pytestmark", [])
        for m in marks:
            if getattr(m, "name", "") == "parametrize":
                # parametrize takes (argnames, argvalues) positionally or as kwargs;
                # prefer kwargs and fall back to positional args.
                kwargs = getattr(m, "kwargs", {})
                argnames = kwargs.get("argnames", (m.args[0] if m.args else []))
                argvalues = kwargs.get("argvalues", (m.args[1] if len(m.args) > 1 else []))
                # Normalize argnames to list (pytest accepts "a,b" strings or sequences)
                if isinstance(argnames, str):
                    names_list = [n.strip() for n in argnames.split(",") if n.strip()]
                else:
                    names_list = list(argnames)
                if "data_loaders" not in names_list:
                    continue
                idx = names_list.index("data_loaders")
                # argvalues is a list of tuples/values aligned with argnames
                for val in argvalues:
                    # Normalize to tuple
                    if not isinstance(val, (tuple, list)):
                        params = (val,)
                    else:
                        params = tuple(val)
                    if idx >= len(params):
                        continue
                    dataloaders_obj = params[idx]
                    # May be a list or single loader
                    candidates = (
                        list(dataloaders_obj) if isinstance(dataloaders_obj, (list, tuple)) else [dataloaders_obj]
                    )
                    for dl in candidates:
                        jsonl_path = getattr(dl, "jsonl_path", None)
                        if isinstance(jsonl_path, str) and jsonl_path:
                            if os.path.isabs(jsonl_path):
                                return jsonl_path
                            # Relative paths are resolved against the test
                            # file's directory, not the current working dir.
                            base_dir = os.path.dirname(os.path.abspath(test_file_path))
                            return os.path.abspath(os.path.join(base_dir, jsonl_path))
        return None
    except Exception:
        # Best-effort extraction: any import/exec/introspection failure just
        # means we cannot infer a dataset path from this test.
        return None


def _build_trimmed_dataset_id(evaluator_id: str) -> str:
    """Build a dataset id derived from evaluator_id, trimmed to 63 chars.

    Format: <normalized-base>-dataset-YYYYMMDDHHMMSS, where the base portion is
    truncated as needed so the complete id fits within the 63-character limit.
    """
    # Normalize base similarly to evaluator id rules
    from .upload import _normalize_evaluator_id  # local import to avoid cycle at module import time

    normalized = _normalize_evaluator_id(evaluator_id)
    timestamp_suffix = f"-dataset-{time.strftime('%Y%m%d%H%M%S')}"
    # Character budget left for the base once the suffix is accounted for.
    budget = max(63 - len(timestamp_suffix), 1)
    if len(normalized) > budget:
        normalized = normalized[:budget].rstrip("-")
    if not normalized:
        normalized = "dataset"
    # Ids must start with a letter; prefix and re-trim if they don't.
    if not normalized[0].isalpha():
        normalized = f"eval-{normalized}"
    if len(normalized) > budget:
        normalized = normalized[:budget]
    normalized = normalized.rstrip("-") or "dataset"
    return f"{normalized}{timestamp_suffix}"


def _auto_select_evaluator_id(cwd: str) -> Optional[str]:
# Try local traces
traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators")
Expand Down Expand Up @@ -101,75 +221,117 @@ def create_rft_command(args) -> int:
print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
return 1

# Resolve evaluator resource name via local trace
# trace = load_evaluator_trace(project_root, evaluator_id)
# if not trace or not isinstance(trace, dict):
# print(
# "Error: Evaluator trace not found. Run 'eval-protocol upload' first or provide --dataset-id/--dataset-jsonl and --evaluator-id."
# )
# return 1
# evaluator_resource_name = trace.get("evaluator_resource_name") or trace.get("name") or evaluator_id
evaluator_resource_name = evaluator_id
# Resolve evaluator resource name to fully-qualified format required by API
evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"

# Ensure evaluator exists by invoking the upload flow programmatically
try:
from .upload import upload_command

tests = _discover_tests(project_root)
selected_entry: Optional[str] = None
if len(tests) == 1:
func_name = tests[0].qualname.split(".")[-1]
abs_path = os.path.abspath(tests[0].file_path)
try:
rel = os.path.relpath(abs_path, project_root)
except Exception:
rel = abs_path
selected_entry = f"{rel}::{func_name}"
else:
# Try to match evaluator_id to a discovered test's normalized ID
for t in tests:
func_name = t.qualname.split(".")[-1]
source_file_name = os.path.splitext(os.path.basename(t.file_path))[0]
candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
if candidate == evaluator_id:
abs_path = os.path.abspath(t.file_path)
try:
rel = os.path.relpath(abs_path, project_root)
except Exception:
rel = abs_path
selected_entry = f"{rel}::{func_name}"
break

upload_args = argparse.Namespace(
path=project_root,
entry=selected_entry,
id=evaluator_id,
display_name=None,
description=None,
force=False,
yes=True,
)
rc = upload_command(upload_args)
if rc == 0:
print(f"✓ Uploaded/ensured evaluator: {evaluator_id}")
else:
print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.")
except Exception as e:
print(f"Warning: Failed to upload evaluator automatically: {e}")

# Determine dataset id and materialization path
dataset_id = getattr(args, "dataset_id", None)
dataset_jsonl = getattr(args, "dataset_jsonl", None)
dataset_display_name = getattr(args, "dataset_display_name", None)
dataset_builder = getattr(args, "dataset_builder", None)
dataset_builder = getattr(args, "dataset_builder", None) # accepted but unused in simplified flow

if not dataset_id:
# Try builder from args, else from trace detection
# TODO: build dataset from traces directly
# builder_spec = dataset_builder or trace.get("dataset_builder")
# if not builder_spec:
# # Attempt detect from metric_dir
# metric_dir = trace.get("metric_dir")
# if metric_dir:
# builder_spec = detect_dataset_builder(metric_dir)
# if not builder_spec:
# print(
# "Error: Could not determine dataset. Provide --dataset-id, --dataset-jsonl, or --dataset-builder."
# )
# return 1
# try:
# dataset_jsonl, count = materialize_dataset_via_builder(builder_spec)
# print(f"✓ Materialized dataset via builder ({builder_spec}): {count} rows → {dataset_jsonl}")
# except Exception as e:
# print(f"Error: dataset builder failed: {e}")
# return 1

# Prefer explicit --dataset-jsonl, else attempt to extract from data loader of the single discovered test
if not dataset_jsonl:
print("Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl.")
tests = _discover_tests(project_root)
if len(tests) == 1:
func_name = tests[0].qualname.split(".")[-1]
dataset_jsonl = _extract_jsonl_from_dataloader(tests[0].file_path, func_name)
if dataset_jsonl:
# Display relative path for readability
try:
rel = os.path.relpath(dataset_jsonl, project_root)
except Exception:
rel = dataset_jsonl
print(f"✓ Using JSONL from data loader: {rel}")
if not dataset_jsonl:
print(
"Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader is used in your single discovered test."
)
return 1

inferred_dataset_id = build_default_dataset_id(evaluator_id)
inferred_dataset_id = _build_trimmed_dataset_id(evaluator_id)
if dry_run:
print("--dry-run: would create dataset and upload JSONL")
dataset_id = inferred_dataset_id
else:
try:
# Resolve dataset_jsonl path relative to CWD if needed
jsonl_path_for_upload = (
dataset_jsonl
if os.path.isabs(dataset_jsonl)
else os.path.abspath(os.path.join(project_root, dataset_jsonl))
)
dataset_id, _ = create_dataset_from_jsonl(
account_id=account_id,
api_key=api_key,
api_base=api_base,
dataset_id=inferred_dataset_id,
display_name=dataset_display_name or inferred_dataset_id,
jsonl_path=dataset_jsonl,
jsonl_path=jsonl_path_for_upload,
)
print(f"✓ Created and uploaded dataset: {dataset_id}")
except Exception as e:
print(f"Error creating/uploading dataset: {e}")
return 1

# Build training config/body
training_config: Dict[str, Any] = {}
if getattr(args, "base_model", None):
training_config["baseModel"] = args.base_model
# Ensure base model is explicitly provided for clarity
if not getattr(args, "base_model", None):
print(
"Error: --base-model is required. Please specify the base model resource id (e.g., accounts/{account}/models/<model_id>)."
)
return 1

training_config: Dict[str, Any] = {"baseModel": args.base_model}
if getattr(args, "warm_start_from", None):
training_config["warmStartFrom"] = args.warm_start_from
if "baseModel" not in training_config and "warmStartFrom" not in training_config:
# Provide a conservative default if neither is set
training_config["baseModel"] = "accounts/fireworks/models/llama-v3p1-8b-instruct"

# Optional hyperparameters
for key, arg_name in [
Expand Down Expand Up @@ -221,14 +383,12 @@ def create_rft_command(args) -> int:
"outputMetrics": None,
"mcpServer": None,
}
print("Show body:")
print(json.dumps(body, indent=2))
# Debug: print minimal summary
print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
if getattr(args, "evaluation_dataset", None):
body["evaluationDataset"] = args.evaluation_dataset
if getattr(args, "output_model", None):
body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{args.output_model}"
else:
body.setdefault("trainingConfig", {})["outputModel"] = build_default_output_model(evaluator_id)

# Clean None fields to avoid noisy payloads
body = {k: v for k, v in body.items() if v is not None}
Expand Down
3 changes: 2 additions & 1 deletion eval_protocol/data_loader/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .dynamic_data_loader import DynamicDataLoader
from .inline_data_loader import InlineDataLoader
from .jsonl_data_loader import EvaluationRowJsonlDataLoader

__all__ = ["DynamicDataLoader", "InlineDataLoader"]
__all__ = ["DynamicDataLoader", "InlineDataLoader", "EvaluationRowJsonlDataLoader"]
42 changes: 42 additions & 0 deletions eval_protocol/data_loader/jsonl_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
from __future__ import annotations

import os
from dataclasses import dataclass
from collections.abc import Sequence

from eval_protocol.common_utils import load_jsonl
from eval_protocol.pytest.default_dataset_adapter import default_dataset_adapter
from eval_protocol.data_loader.models import (
DataLoaderResult,
DataLoaderVariant,
EvaluationDataLoader,
)


@dataclass(kw_only=True)
class EvaluationRowJsonlDataLoader(EvaluationDataLoader):
    """Data loader that reads EvaluationRows from a JSONL file path.

    Each line of the JSONL file should be a serialized EvaluationRow dict;
    rows are turned into EvaluationRow objects via the default dataset adapter.
    """

    jsonl_path: str
    id: str = "jsonl"
    description: str | None = None

    def variants(self) -> Sequence[DataLoaderVariant]:
        # Single lazy variant: the file is only read when the variant runs.
        def _load_rows() -> DataLoaderResult:
            # Relative paths are resolved against the current working directory.
            resolved = self.jsonl_path if os.path.isabs(self.jsonl_path) else os.path.abspath(self.jsonl_path)
            raw_records = load_jsonl(resolved)
            return DataLoaderResult(
                rows=default_dataset_adapter(raw_records),
                type=type(self).__name__,
                variant_id=self.id,
                variant_description=self.description,
            )

        return [_load_rows]
Loading
Loading