Skip to content

Commit 00352c9

Browse files
committed
auto convert from dict
1 parent 495ff63 commit 00352c9

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

eval_protocol/models.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
from openai.types.chat.chat_completion_message_tool_call import (
1313
ChatCompletionMessageToolCall,
1414
)
15-
from pydantic import BaseModel, ConfigDict, Field
15+
from pydantic import BaseModel, ConfigDict, Field, field_validator
1616

1717
from eval_protocol.get_pep440_version import get_pep440_version
1818
from eval_protocol.human_id import generate_id
@@ -595,7 +595,7 @@ class EvaluationRow(BaseModel):
595595
supporting both row-wise batch evaluation and trajectory-based RL evaluation.
596596
"""
597597

598-
model_config = ConfigDict(extra="allow")
598+
model_config = ConfigDict(extra="allow", validate_assignment=True)
599599

600600
# Core OpenAI ChatCompletion compatible conversation data
601601
messages: List[Message] = Field(description="List of messages in the conversation. Also known as a trajectory.")
@@ -626,6 +626,17 @@ class EvaluationRow(BaseModel):
626626
default=None, description="The evaluation result for this row/trajectory."
627627
)
628628

629+
@field_validator("evaluation_result", mode="before")
630+
@classmethod
631+
def _coerce_evaluation_result(
632+
cls, value: EvaluateResult | dict[str, Any] | None
633+
) -> EvaluateResult | None:
634+
if value is None or isinstance(value, EvaluateResult):
635+
return value
636+
if isinstance(value, dict):
637+
return EvaluateResult(**value)
638+
return value
639+
629640
execution_metadata: ExecutionMetadata = Field(
630641
default_factory=lambda: ExecutionMetadata(run_id=None),
631642
description="Metadata about the execution of the evaluation.",

tests/test_models.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -198,8 +198,14 @@ def test_metric_result_dict_access():
198198
}
199199
assert set(metric.items()) == expected_items
200200

201-
# __iter__
202-
assert set(list(metric)) == {"score", "reason", "is_score_valid"}
201+
202+
def test_evaluation_row_accepts_dict_assignment_for_evaluation_result():
203+
row = dummy_row()
204+
row.evaluation_result = {"score": 0.6}
205+
206+
assert isinstance(row.evaluation_result, EvaluateResult)
207+
assert row.evaluation_result.score == 0.6
208+
assert row.evaluation_result.is_score_valid is True
203209

204210

205211
def test_evaluate_result_dict_access():

0 commit comments

Comments
 (0)