diff --git a/eval_protocol/models.py b/eval_protocol/models.py
index 2804db59..e2bf355f 100644
--- a/eval_protocol/models.py
+++ b/eval_protocol/models.py
@@ -269,6 +269,11 @@ class Message(BaseModel):
     tool_calls: Optional[List[ChatCompletionMessageToolCall]] = None
     function_call: Optional[FunctionCall] = None
     control_plane_step: Optional[Dict[str, Any]] = None
+    weight: Optional[int] = None
+
+    def dump_mdoel_for_chat_completion_request(self):
+        """Only keep chat completion accepted fields"""
+        return self.model_dump(exclude_none=True, exclude={"control_plane_step", "reasoning_content", "weight"})
 
     @classmethod
     def model_validate(cls, obj, *args, **kwargs):
diff --git a/tests/adapters/test_openai_responses_adapter.py b/tests/adapters/test_openai_responses_adapter.py
index bf9edd98..6091e52f 100644
--- a/tests/adapters/test_openai_responses_adapter.py
+++ b/tests/adapters/test_openai_responses_adapter.py
@@ -22,7 +22,10 @@ def test_openai_responses_adapter_with_real_response_simple(snapshot: SnapshotAs
     assert len(eval_rows) == 1
 
     # Convert to dict for snapshot testing
-    eval_rows_dict = [row.model_dump(exclude={"created_at", "execution_metadata"}) for row in eval_rows]
+    eval_rows_dict = [
+        row.model_dump(exclude={"created_at": True, "execution_metadata": True, "messages": {"__all__": {"weight"}}})
+        for row in eval_rows
+    ]
 
     # Assert against snapshot
     assert eval_rows_dict == snapshot
@@ -42,7 +45,10 @@ def test_openai_responses_adapter_with_real_response_parallel_tool_calls(snapsho
     assert len(eval_rows) == 1
 
     # Convert to dict for snapshot testing
-    eval_rows_dict = [row.model_dump(exclude={"created_at", "execution_metadata"}) for row in eval_rows]
+    eval_rows_dict = [
+        row.model_dump(exclude={"created_at": True, "execution_metadata": True, "messages": {"__all__": {"weight"}}})
+        for row in eval_rows
+    ]
 
     # Assert against snapshot
     assert eval_rows_dict == snapshot
diff --git a/tests/test_models.py b/tests/test_models.py
index 9e0f09f9..723685b8 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -694,3 +694,30 @@ def test_evaluation_row_extra_fields():
     assert "eval" in dictionary
     assert "accuracy" in dictionary["eval_details"]["metrics"]
     assert "test" in dictionary["extra_fields"]
+
+
+def test_message_with_weight_dump():
+    example = {
+        "role": "user",
+        "content": "Hello, how are you?",
+        "weight": 0,
+    }
+
+    message = Message(**example)
+    dictionary = message.model_dump()
+    assert "weight" in dictionary
+    assert dictionary["weight"] == 0
+
+
+def test_message_dump_for_chat_completion_request():
+    example = {
+        "role": "user",
+        "content": "Hello, how are you?",
+        "weight": 0,
+        "reasoning_content": "I am thinking about the user's question",
+    }
+    message = Message(**example)
+    dictionary = message.dump_mdoel_for_chat_completion_request()
+    assert "weight" not in dictionary
+    assert "reasoning_content" not in dictionary
+    assert dictionary["content"] == "Hello, how are you?"