|
| 1 | +""" |
| 2 | +Integration helpers between Eval Protocol evaluations and OpenAI RFT graders. |
| 3 | +
|
| 4 | +Currently provides: |
| 5 | +- build_python_grader_from_evaluation_test: turn an evaluation-style function into |
| 6 | + an OpenAI Python grader spec ({"type": "python", "source": ...}). |
| 7 | +""" |
| 8 | + |
| 9 | +import ast |
| 10 | +import inspect |
| 11 | +import textwrap |
| 12 | + |
| 13 | + |
def build_python_grader_from_evaluation_test(test_fn) -> dict:
    """
    Return an OpenAI Python grader spec ({"type": "python", "source": ...})
    built from an Eval Protocol-style evaluation function.

    Assumptions:
    - `test_fn` is either:
        * the core evaluation function, or
        * an @evaluation_test-decorated function that carries `_origin_func`.
      Its effective signature looks like:

          def my_eval(row, **kwargs) -> EvaluateResult | float | EvaluationRow

      The first positional parameter receives the row; it does not have to be
      named `row`. Async evaluation functions are also supported: the generated
      grader runs the returned coroutine to completion with asyncio.run().

    - The function treats its first argument as an `EvaluationRow` and only
      relies on attributes we provide in the duck-typed stand-in:
        * row.ground_truth
        * row.messages
        * row.item (raw item dict)
        * row.sample (raw sample dict)

    - We map OpenAI's (sample, item) into that duck-typed `EvaluationRow` as follows:
        * item["reference_answer"] -> row.ground_truth
        * item["messages"] (if present) -> row.messages (normalized to Message-like objects)
        * sample["output_text"] -> appended as the last assistant message in row.messages
        * the original dicts are also available via row.item / row.sample

    - The function returns either:
        * a numeric score, or
        * an object/dict with a `score` field, or
        * an EvaluationRow/EvaluateResult-like object with `.evaluation_result.score`.
      Any other shape grades as 0.0 (best-effort normalization).
    """

    # If the user passed an @evaluation_test wrapper, try to recover the original function
    origin = getattr(test_fn, "_origin_func", test_fn)

    # Get the source of the original function, dedented so that functions
    # defined at non-zero indentation (methods, nested defs) still parse.
    src = textwrap.dedent(inspect.getsource(origin))

    # Parse into AST so we can safely strip decorators and type annotations.
    # Signature annotations are evaluated at def-time and typically reference
    # names (EvaluationRow, EvaluateResult, ...) that the sandboxed grader
    # source only defines *after* the eval function — so they must be removed,
    # not just ignored.
    tree = ast.parse(src)

    class _StripAnnotationsAndDecorators(ast.NodeTransformer):
        def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.AST:
            # Drop all decorators (e.g., @evaluation_test)
            node.decorator_list = []
            # Remove return type annotation
            node.returns = None
            self.generic_visit(node)
            return node

        def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> ast.AST:
            node.decorator_list = []
            node.returns = None
            self.generic_visit(node)
            return node

        def visit_arg(self, node: ast.arg) -> ast.AST:
            # Remove all parameter annotations (e.g., row: EvaluationRow)
            node.annotation = None
            return node

    tree = _StripAnnotationsAndDecorators().visit(tree)
    ast.fix_missing_locations(tree)

    # Find the first (possibly async) function definition and rename it so the
    # generated grade() helper can call it under a fixed name.
    func_node: ast.FunctionDef | ast.AsyncFunctionDef | None = None
    for node in tree.body:
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)):
            func_node = node
            break

    if func_node is None:
        raise ValueError("Expected a function definition in test_fn source.")

    func_node.name = "_ep_eval"

    # Turn the modified AST back into source
    src = ast.unparse(tree)

    # Helper code that will live *inside* the grader source
    helper = """
import asyncio
import inspect
from collections.abc import Mapping
from types import SimpleNamespace
from typing import Any, Dict


class EvaluationRow(SimpleNamespace):
    \"\"\"Minimal duck-typed stand-in for an evaluation row.

    Extend this with whatever attributes your eval logic uses.
    \"\"\"
    pass


class EvaluateResult(SimpleNamespace):
    \"\"\"Simple stand-in for Eval Protocol's EvaluateResult.

    This lets evaluation-style functions that construct EvaluateResult(score=...)
    run inside the Python grader sandbox without importing eval_protocol.
    \"\"\"

    def __init__(self, score: float, **kwargs: Any) -> None:
        super().__init__(score=score, **kwargs)


class Message(SimpleNamespace):
    \"\"\"Duck-typed stand-in for eval_protocol.models.Message (role/content).\"\"\"
    pass


def _build_row(sample: Dict[str, Any], item: Dict[str, Any]) -> EvaluationRow:
    # Start from any item-provided messages (EP-style), defaulting to [].
    raw_messages = item.get("messages") or []
    normalized_messages = []
    for m in raw_messages:
        if isinstance(m, dict):
            normalized_messages.append(
                Message(
                    role=m.get("role"),
                    content=m.get("content"),
                )
            )
        else:
            # Already Message-like; rely on duck typing (must have role/content)
            normalized_messages.append(m)

    reference = item.get("reference_answer")
    prediction = sample.get("output_text")

    # EP-style: ensure the model prediction is present as the last assistant message
    if prediction is not None:
        normalized_messages = list(normalized_messages)  # shallow copy
        normalized_messages.append(Message(role="assistant", content=prediction))

    return EvaluationRow(
        ground_truth=reference,
        messages=normalized_messages,
        item=item,
        sample=sample,
    )


def grade(sample: Dict[str, Any], item: Dict[str, Any]) -> float:
    row = _build_row(sample, item)
    # Call positionally: the eval function's first parameter receives the row
    # regardless of what that parameter is named.
    result = _ep_eval(row)
    # Async evaluation functions hand back a coroutine; run it to completion.
    if inspect.iscoroutine(result):
        result = asyncio.run(result)

    # Try to normalize different result shapes into a float score
    try:
        if isinstance(result, (int, float)):
            return float(result)

        # EvaluateResult-like object with .score
        if hasattr(result, "score"):
            return float(result.score)

        # EvaluationRow-like object with .evaluation_result.score
        eval_res = getattr(result, "evaluation_result", None)
        if eval_res is not None:
            if isinstance(eval_res, Mapping):
                if "score" in eval_res:
                    return float(eval_res["score"])
            elif hasattr(eval_res, "score"):
                return float(eval_res.score)

        # Dict-like with score
        if isinstance(result, Mapping) and "score" in result:
            return float(result["score"])
    except Exception:
        # Best-effort normalization: any unexpected shape grades as 0.0.
        pass

    return 0.0
"""

    full_source = src + "\n\n" + textwrap.dedent(helper)
    return {"type": "python", "source": full_source}
0 commit comments