diff --git a/eval_protocol/cli_commands/create_rft.py b/eval_protocol/cli_commands/create_rft.py
index cb78fbae..9580340a 100644
--- a/eval_protocol/cli_commands/create_rft.py
+++ b/eval_protocol/cli_commands/create_rft.py
@@ -1,6 +1,8 @@
 import json
 import os
 import sys
+import time
+import argparse
 from typing import Any, Dict, Optional
 
 from ..auth import (
@@ -11,13 +13,8 @@
 )
 from ..fireworks_rft import (
     _map_api_host_to_app_host,
-    build_default_dataset_id,
-    build_default_output_model,
     create_dataset_from_jsonl,
     create_reinforcement_fine_tuning_job,
-    detect_dataset_builder,
-    load_evaluator_trace,
-    materialize_dataset_via_builder,
 )
 from .upload import _discover_tests, _normalize_evaluator_id, _resolve_entry_to_qual_and_source
@@ -58,6 +55,129 @@ def _print_links(evaluator_id: str, dataset_id: str, job_name: Optional[str]) ->
     pass
 
 
+def _auto_find_jsonl(cwd: str) -> Optional[str]:
+    """Find a reasonable JSONL dataset file in the current project.
+
+    Priority order:
+    - dataset.jsonl in cwd
+    - data/dataset.jsonl
+    - first *.jsonl under cwd (depth-first, skipping common vendor/venv/build dirs)
+    Returns a RELATIVE path from cwd if possible.
+    """
+    # Direct candidates
+    direct_candidates = [
+        os.path.join(cwd, "dataset.jsonl"),
+        os.path.join(cwd, "data", "dataset.jsonl"),
+    ]
+    for p in direct_candidates:
+        if os.path.isfile(p):
+            try:
+                return os.path.relpath(p, cwd)
+            except Exception:
+                return p
+
+    # Walk and find any .jsonl
+    skip_dirs = {".venv", "venv", "node_modules", "dist", "build", "__pycache__", ".git", "vendor"}
+    for dirpath, dirnames, filenames in os.walk(cwd):
+        # prune
+        dirnames[:] = [d for d in dirnames if d not in skip_dirs and not d.startswith(".")]
+        for name in sorted(filenames):
+            if name.endswith(".jsonl"):
+                candidate = os.path.join(dirpath, name)
+                try:
+                    return os.path.relpath(candidate, cwd)
+                except Exception:
+                    return candidate
+    return None
+
+
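# Illustration (hypothetical layout, not part of the diff): in a project that
# contains both data/dataset.jsonl and notes/extra.jsonl, the direct candidates
# win, so
#
#   _auto_find_jsonl("/proj")  # -> "data/dataset.jsonl"
#
# The os.walk fallback only runs when neither direct candidate exists.
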
+ """ + try: + import importlib.util + from pathlib import Path + + spec = importlib.util.spec_from_file_location(Path(test_file_path).stem, test_file_path) + if not spec or not spec.loader: + return None + module = importlib.util.module_from_spec(spec) + sys.modules[spec.name] = module + spec.loader.exec_module(module) # type: ignore[attr-defined] + if not hasattr(module, test_func_name): + return None + wrapper = getattr(module, test_func_name) + marks = getattr(wrapper, "pytestmark", []) + for m in marks: + if getattr(m, "name", "") == "parametrize": + kwargs = getattr(m, "kwargs", {}) + argnames = kwargs.get("argnames", (m.args[0] if m.args else [])) + argvalues = kwargs.get("argvalues", (m.args[1] if len(m.args) > 1 else [])) + # Normalize argnames to list + if isinstance(argnames, str): + names_list = [n.strip() for n in argnames.split(",") if n.strip()] + else: + names_list = list(argnames) + if "data_loaders" not in names_list: + continue + idx = names_list.index("data_loaders") + # argvalues is a list of tuples/values aligned with argnames + for val in argvalues: + # Normalize to tuple + if not isinstance(val, (tuple, list)): + params = (val,) + else: + params = tuple(val) + if idx >= len(params): + continue + dataloaders_obj = params[idx] + # May be a list or single loader + candidates = ( + list(dataloaders_obj) if isinstance(dataloaders_obj, (list, tuple)) else [dataloaders_obj] + ) + for dl in candidates: + jsonl_path = getattr(dl, "jsonl_path", None) + if isinstance(jsonl_path, str) and jsonl_path: + if os.path.isabs(jsonl_path): + return jsonl_path + base_dir = os.path.dirname(os.path.abspath(test_file_path)) + return os.path.abspath(os.path.join(base_dir, jsonl_path)) + return None + except Exception: + return None + + +def _build_trimmed_dataset_id(evaluator_id: str) -> str: + """Build a dataset id derived from evaluator_id, trimmed to 63 chars. + + Format: -dataset-YYYYMMDDHHMMSS, where base is trimmed to fit. + """ + # Normalize base similarly to evaluator id rules + from .upload import _normalize_evaluator_id # local import to avoid cycle at module import time + + base = _normalize_evaluator_id(evaluator_id) + suffix = f"-dataset-{time.strftime('%Y%m%d%H%M%S')}" + max_total = 63 + max_base_len = max_total - len(suffix) + if max_base_len < 1: + max_base_len = 1 + if len(base) > max_base_len: + base = base[:max_base_len].rstrip("-") + if not base: + base = "dataset" + # Ensure first char is a letter + if not base[0].isalpha(): + base = f"eval-{base}" + if len(base) > max_base_len: + base = base[:max_base_len] + base = base.rstrip("-") or "dataset" + return f"{base}{suffix}" + + def _auto_select_evaluator_id(cwd: str) -> Optional[str]: # Try local traces traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators") @@ -101,60 +221,100 @@ def create_rft_command(args) -> int: print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.") return 1 - # Resolve evaluator resource name via local trace - # trace = load_evaluator_trace(project_root, evaluator_id) - # if not trace or not isinstance(trace, dict): - # print( - # "Error: Evaluator trace not found. Run 'eval-protocol upload' first or provide --dataset-id/--dataset-jsonl and --evaluator-id." 
- # ) - # return 1 - # evaluator_resource_name = trace.get("evaluator_resource_name") or trace.get("name") or evaluator_id - evaluator_resource_name = evaluator_id + # Resolve evaluator resource name to fully-qualified format required by API + evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}" + + # Ensure evaluator exists by invoking the upload flow programmatically + try: + from .upload import upload_command + + tests = _discover_tests(project_root) + selected_entry: Optional[str] = None + if len(tests) == 1: + func_name = tests[0].qualname.split(".")[-1] + abs_path = os.path.abspath(tests[0].file_path) + try: + rel = os.path.relpath(abs_path, project_root) + except Exception: + rel = abs_path + selected_entry = f"{rel}::{func_name}" + else: + # Try to match evaluator_id to a discovered test's normalized ID + for t in tests: + func_name = t.qualname.split(".")[-1] + source_file_name = os.path.splitext(os.path.basename(t.file_path))[0] + candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}") + if candidate == evaluator_id: + abs_path = os.path.abspath(t.file_path) + try: + rel = os.path.relpath(abs_path, project_root) + except Exception: + rel = abs_path + selected_entry = f"{rel}::{func_name}" + break + + upload_args = argparse.Namespace( + path=project_root, + entry=selected_entry, + id=evaluator_id, + display_name=None, + description=None, + force=False, + yes=True, + ) + rc = upload_command(upload_args) + if rc == 0: + print(f"✓ Uploaded/ensured evaluator: {evaluator_id}") + else: + print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.") + except Exception as e: + print(f"Warning: Failed to upload evaluator automatically: {e}") # Determine dataset id and materialization path dataset_id = getattr(args, "dataset_id", None) dataset_jsonl = getattr(args, "dataset_jsonl", None) dataset_display_name = getattr(args, "dataset_display_name", None) - dataset_builder = getattr(args, "dataset_builder", None) + dataset_builder = getattr(args, "dataset_builder", None) # accepted but unused in simplified flow if not dataset_id: - # Try builder from args, else from trace detection - # TODO: build dataset from traces directly - # builder_spec = dataset_builder or trace.get("dataset_builder") - # if not builder_spec: - # # Attempt detect from metric_dir - # metric_dir = trace.get("metric_dir") - # if metric_dir: - # builder_spec = detect_dataset_builder(metric_dir) - # if not builder_spec: - # print( - # "Error: Could not determine dataset. Provide --dataset-id, --dataset-jsonl, or --dataset-builder." - # ) - # return 1 - # try: - # dataset_jsonl, count = materialize_dataset_via_builder(builder_spec) - # print(f"✓ Materialized dataset via builder ({builder_spec}): {count} rows → {dataset_jsonl}") - # except Exception as e: - # print(f"Error: dataset builder failed: {e}") - # return 1 - + # Prefer explicit --dataset-jsonl, else attempt to extract from data loader of the single discovered test if not dataset_jsonl: - print("Error: Could not determine dataset. 
 def _auto_select_evaluator_id(cwd: str) -> Optional[str]:
     # Try local traces
     traces_dir = os.path.join(cwd, ".eval_protocol", "evaluators")
@@ -101,60 +221,100 @@ def create_rft_command(args) -> int:
         print("Error: Could not infer evaluator id. Provide --evaluator-id or run 'eval-protocol upload' first.")
         return 1
 
-    # Resolve evaluator resource name via local trace
-    # trace = load_evaluator_trace(project_root, evaluator_id)
-    # if not trace or not isinstance(trace, dict):
-    #     print(
-    #         "Error: Evaluator trace not found. Run 'eval-protocol upload' first or provide --dataset-id/--dataset-jsonl and --evaluator-id."
-    #     )
-    #     return 1
-    # evaluator_resource_name = trace.get("evaluator_resource_name") or trace.get("name") or evaluator_id
-    evaluator_resource_name = evaluator_id
+    # Resolve the evaluator resource name to the fully qualified format required by the API
+    evaluator_resource_name = f"accounts/{account_id}/evaluators/{evaluator_id}"
+
+    # Ensure the evaluator exists by invoking the upload flow programmatically
+    try:
+        from .upload import upload_command
+
+        tests = _discover_tests(project_root)
+        selected_entry: Optional[str] = None
+        if len(tests) == 1:
+            func_name = tests[0].qualname.split(".")[-1]
+            abs_path = os.path.abspath(tests[0].file_path)
+            try:
+                rel = os.path.relpath(abs_path, project_root)
+            except Exception:
+                rel = abs_path
+            selected_entry = f"{rel}::{func_name}"
+        else:
+            # Try to match evaluator_id to a discovered test's normalized id
+            for t in tests:
+                func_name = t.qualname.split(".")[-1]
+                source_file_name = os.path.splitext(os.path.basename(t.file_path))[0]
+                candidate = _normalize_evaluator_id(f"{source_file_name}-{func_name}")
+                if candidate == evaluator_id:
+                    abs_path = os.path.abspath(t.file_path)
+                    try:
+                        rel = os.path.relpath(abs_path, project_root)
+                    except Exception:
+                        rel = abs_path
+                    selected_entry = f"{rel}::{func_name}"
+                    break
+
+        upload_args = argparse.Namespace(
+            path=project_root,
+            entry=selected_entry,
+            id=evaluator_id,
+            display_name=None,
+            description=None,
+            force=False,
+            yes=True,
+        )
+        rc = upload_command(upload_args)
+        if rc == 0:
+            print(f"✓ Uploaded/ensured evaluator: {evaluator_id}")
+        else:
+            print("Warning: Evaluator upload did not complete successfully; proceeding to RFT creation.")
+    except Exception as e:
+        print(f"Warning: Failed to upload evaluator automatically: {e}")
 
     # Determine dataset id and materialization path
     dataset_id = getattr(args, "dataset_id", None)
     dataset_jsonl = getattr(args, "dataset_jsonl", None)
     dataset_display_name = getattr(args, "dataset_display_name", None)
-    dataset_builder = getattr(args, "dataset_builder", None)
+    dataset_builder = getattr(args, "dataset_builder", None)  # accepted but unused in the simplified flow
 
     if not dataset_id:
-        # Try builder from args, else from trace detection
-        # TODO: build dataset from traces directly
-        # builder_spec = dataset_builder or trace.get("dataset_builder")
-        # if not builder_spec:
-        #     # Attempt detect from metric_dir
-        #     metric_dir = trace.get("metric_dir")
-        #     if metric_dir:
-        #         builder_spec = detect_dataset_builder(metric_dir)
-        # if not builder_spec:
-        #     print(
-        #         "Error: Could not determine dataset. Provide --dataset-id, --dataset-jsonl, or --dataset-builder."
-        #     )
-        #     return 1
-        # try:
-        #     dataset_jsonl, count = materialize_dataset_via_builder(builder_spec)
-        #     print(f"✓ Materialized dataset via builder ({builder_spec}): {count} rows → {dataset_jsonl}")
-        # except Exception as e:
-        #     print(f"Error: dataset builder failed: {e}")
-        #     return 1
-
+        # Prefer an explicit --dataset-jsonl, else try to extract one from the data loader of the single discovered test
         if not dataset_jsonl:
-            print("Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl.")
+            tests = _discover_tests(project_root)
+            if len(tests) == 1:
+                func_name = tests[0].qualname.split(".")[-1]
+                dataset_jsonl = _extract_jsonl_from_dataloader(tests[0].file_path, func_name)
+                if dataset_jsonl:
+                    # Display a relative path for readability
+                    try:
+                        rel = os.path.relpath(dataset_jsonl, project_root)
+                    except Exception:
+                        rel = dataset_jsonl
+                    print(f"✓ Using JSONL from data loader: {rel}")
+        if not dataset_jsonl:
+            print(
+                "Error: Could not determine dataset. Provide --dataset-id or --dataset-jsonl, or ensure a JSONL-based data loader is used in your single discovered test."
+            )
             return 1
 
-        inferred_dataset_id = build_default_dataset_id(evaluator_id)
+        inferred_dataset_id = _build_trimmed_dataset_id(evaluator_id)
         if dry_run:
             print("--dry-run: would create dataset and upload JSONL")
             dataset_id = inferred_dataset_id
         else:
             try:
+                # Resolve the dataset_jsonl path relative to the project root if needed
+                jsonl_path_for_upload = (
+                    dataset_jsonl
+                    if os.path.isabs(dataset_jsonl)
+                    else os.path.abspath(os.path.join(project_root, dataset_jsonl))
+                )
                 dataset_id, _ = create_dataset_from_jsonl(
                     account_id=account_id,
                     api_key=api_key,
                     api_base=api_base,
                     dataset_id=inferred_dataset_id,
                     display_name=dataset_display_name or inferred_dataset_id,
-                    jsonl_path=dataset_jsonl,
+                    jsonl_path=jsonl_path_for_upload,
                 )
                 print(f"✓ Created and uploaded dataset: {dataset_id}")
             except Exception as e:
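# Illustration (subcommand name assumed from this module's naming; flags taken
# from the surrounding code, and the evaluator id and JSONL path are hypothetical):
#
#   eval-protocol create-rft --evaluator-id my-eval --dataset-jsonl data/rows.jsonl \
#       --base-model accounts/fireworks/models/llama-v3p1-8b-instruct --dry-run
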
@@ -162,14 +322,16 @@ def create_rft_command(args) -> int:
                 return 1
 
     # Build training config/body
-    training_config: Dict[str, Any] = {}
-    if getattr(args, "base_model", None):
-        training_config["baseModel"] = args.base_model
+    # Ensure the base model is explicitly provided for clarity
+    if not getattr(args, "base_model", None):
+        print(
+            "Error: --base-model is required. Please specify the base model resource id (e.g., accounts/{account}/models/<model>)."
+        )
+        return 1
+
+    training_config: Dict[str, Any] = {"baseModel": args.base_model}
     if getattr(args, "warm_start_from", None):
         training_config["warmStartFrom"] = args.warm_start_from
-    if "baseModel" not in training_config and "warmStartFrom" not in training_config:
-        # Provide a conservative default if neither is set
-        training_config["baseModel"] = "accounts/fireworks/models/llama-v3p1-8b-instruct"
 
     # Optional hyperparameters
     for key, arg_name in [
@@ -221,14 +383,12 @@ def create_rft_command(args) -> int:
         "outputMetrics": None,
         "mcpServer": None,
     }
-    print("Show body:")
-    print(json.dumps(body, indent=2))
+    # Debug: print a minimal summary
+    print(f"Prepared RFT job for evaluator '{evaluator_id}' using dataset '{dataset_id}'")
 
     if getattr(args, "evaluation_dataset", None):
         body["evaluationDataset"] = args.evaluation_dataset
     if getattr(args, "output_model", None):
         body.setdefault("trainingConfig", {})["outputModel"] = f"accounts/{account_id}/models/{args.output_model}"
-    else:
-        body.setdefault("trainingConfig", {})["outputModel"] = build_default_output_model(evaluator_id)
 
     # Clean None fields to avoid noisy payloads
     body = {k: v for k, v in body.items() if v is not None}
diff --git a/eval_protocol/data_loader/__init__.py b/eval_protocol/data_loader/__init__.py
index 4c92b023..fbd4d3ce 100644
--- a/eval_protocol/data_loader/__init__.py
+++ b/eval_protocol/data_loader/__init__.py
@@ -1,4 +1,5 @@
 from .dynamic_data_loader import DynamicDataLoader
 from .inline_data_loader import InlineDataLoader
+from .jsonl_data_loader import EvaluationRowJsonlDataLoader
 
-__all__ = ["DynamicDataLoader", "InlineDataLoader"]
+__all__ = ["DynamicDataLoader", "InlineDataLoader", "EvaluationRowJsonlDataLoader"]
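# Illustration (not part of the diff): the re-export above lets callers import
# the new loader from the package root; its full definition follows below.
#
#   from eval_protocol.data_loader import EvaluationRowJsonlDataLoader
#
#   loader = EvaluationRowJsonlDataLoader(jsonl_path="development/gsm8k_sample.jsonl")
#
# Each line of the JSONL file is one serialized EvaluationRow dict (the exact
# row schema is not shown in this diff).
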
+ """ + + jsonl_path: str + id: str = "jsonl" + description: str | None = None + + def variants(self) -> Sequence[DataLoaderVariant]: + def _load() -> DataLoaderResult: + path = self.jsonl_path + if not os.path.isabs(path): + path = os.path.abspath(path) + rows_json = load_jsonl(path) + eval_rows = default_dataset_adapter(rows_json) + return DataLoaderResult( + rows=eval_rows, + type=self.__class__.__name__, + variant_id=self.id, + variant_description=self.description, + ) + + return [_load] diff --git a/eval_protocol/fireworks_rft.py b/eval_protocol/fireworks_rft.py index 3fd44eaa..6bd2e62e 100644 --- a/eval_protocol/fireworks_rft.py +++ b/eval_protocol/fireworks_rft.py @@ -18,12 +18,20 @@ def _map_api_host_to_app_host(api_base: str) -> str: from urllib.parse import urlparse parsed = urlparse(api_base) - host = parsed.netloc or parsed.path + host = (parsed.netloc or parsed.path).lower() + scheme = parsed.scheme or "https" + + # Explicit mappings first if host.startswith("dev.api.fireworks.ai"): - return f"{parsed.scheme or 'https'}://dev.fireworks.ai" + return f"{scheme}://dev.fireworks.ai" + if host == "staging.api.fireworks.ai" or host == "api.fireworks.ai": + return f"{scheme}://app.fireworks.ai" + + # Generic mapping: api.<...> → app.<...> if host.startswith("api."): - return f"{parsed.scheme or 'https'}://{host.replace('api.', 'app.', 1)}" - return f"{parsed.scheme or 'https'}://{host}" + return f"{scheme}://{host.replace('api.', 'app.', 1)}" + + return f"{scheme}://{host}" except Exception: return "https://app.fireworks.ai" diff --git a/tests/pytest/gsm8k/test_pytest_math_example.py b/tests/pytest/gsm8k/test_pytest_math_example.py index 0ac994c7..961ff479 100644 --- a/tests/pytest/gsm8k/test_pytest_math_example.py +++ b/tests/pytest/gsm8k/test_pytest_math_example.py @@ -1,19 +1,34 @@ import re from eval_protocol.models import EvaluateResult, EvaluationRow, MetricResult, Message from eval_protocol.pytest import SingleTurnRolloutProcessor, evaluation_test +import os +from eval_protocol.data_loader.jsonl_data_loader import EvaluationRowJsonlDataLoader from typing import List, Dict, Any, Optional def extract_answer_digits(ground_truth: str) -> Optional[str]: """ - Extract the digits from the answer string. + Extract the first sequence of digits within ... tags. + + Returns None if tags are missing or no digits are found. """ - answer_string = ground_truth.split("")[1].split("")[0] - return re.search(r"(\d+)", answer_string).group(1) if answer_string else None + if not ground_truth: + return None + + match = re.search(r"(.*?)", ground_truth, flags=re.IGNORECASE | re.DOTALL) + if not match: + return None + + answer_string = match.group(1) + digits_match = re.search(r"(\d+)", answer_string) + return digits_match.group(1) if digits_match else None + + +JSONL_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../development/gsm8k_sample.jsonl")) @evaluation_test( - input_dataset=["development/gsm8k_sample.jsonl"], + data_loaders=EvaluationRowJsonlDataLoader(jsonl_path=JSONL_PATH), completion_params=[{"temperature": 0.0, "model": "fireworks_ai/accounts/fireworks/models/gpt-oss-120b"}], max_dataset_rows=5, passed_threshold=0.0,