From f74481cd45c7cee8396d117eef103b9126ca5b31 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 20 Aug 2025 10:14:45 -0700 Subject: [PATCH 01/11] save --- eval_protocol/mcp/execution/manager.py | 15 +- eval_protocol/models.py | 228 +++++++++-- eval_protocol/pytest/plugin.py | 2 +- eval_protocol/pytest/utils.py | 19 +- tests/test_migration_changes.py | 446 +++++++++++++++++++++ tests/test_retry_mechanism.py | 6 +- tests/test_status_migration_integration.py | 430 ++++++++++++++++++++ tests/test_status_model.py | 423 +++++++++++++++++++ vite-app/src/types/eval-protocol.ts | 16 +- 9 files changed, 1530 insertions(+), 55 deletions(-) create mode 100644 tests/test_migration_changes.py create mode 100644 tests/test_status_migration_integration.py create mode 100644 tests/test_status_model.py diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 6b1163e9..8e461b49 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -20,7 +20,7 @@ from vendor.tau2.data_model.message import AssistantMessage, UserMessage from vendor.tau2.user.user_simulator import UserSimulator -from ...models import EvaluationRow, InputMetadata, Message, RolloutStatus +from ...models import EvaluationRow, InputMetadata, Message, Status from ...types import TerminationReason, Trajectory, NonSkippableException if TYPE_CHECKING: @@ -136,15 +136,14 @@ async def _execute_with_semaphore(idx): } if trajectory.terminated: - evaluation_row.rollout_status.termination_reason = trajectory.termination_reason - evaluation_row.rollout_status.status = RolloutStatus.Status.FINISHED - # preserve the true error mesage if there are any + extra_info = None if trajectory.control_plane_summary.get("error_message"): - evaluation_row.rollout_status.extra_info = { - "error_message": trajectory.control_plane_summary.get("error_message") - } + extra_info = {"error_message": trajectory.control_plane_summary.get("error_message")} + 
evaluation_row.rollout_status = Status.rollout_finished( + termination_reason=trajectory.termination_reason, extra_info=extra_info + ) else: - evaluation_row.rollout_status.status = RolloutStatus.Status.RUNNING + evaluation_row.rollout_status = Status.rollout_running() return evaluation_row diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 83a0f178..d5464300 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -15,6 +15,199 @@ from eval_protocol.types import TerminationReason +class ErrorInfo(BaseModel): + """ + AIP-193 ErrorInfo model for structured error details. + + This model follows Google's AIP-193 standard for ErrorInfo: + https://google.aip.dev/193#errorinfo + + Attributes: + reason (str): A short snake_case description of the cause of the error. + domain (str): The logical grouping to which the reason belongs. + metadata (Dict[str, Any]): Additional dynamic information as context. + """ + + reason: str = Field(..., description="Short snake_case description of the error cause") + domain: str = Field(..., description="Logical grouping for the error reason") + metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional dynamic information as context") + + def to_aip193_format(self) -> Dict[str, Any]: + """Convert to AIP-193 format with @type field.""" + return { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": self.reason, + "domain": self.domain, + "metadata": self.metadata, + } + + @classmethod + def termination_reason(cls, reason: str) -> "ErrorInfo": + """Create an ErrorInfo for termination reason.""" + return cls(reason="TERMINATION_REASON", domain="evalprotocol.io", metadata={"termination_reason": reason}) + + @classmethod + def extra_info(cls, metadata: Dict[str, Any]) -> "ErrorInfo": + """Create an ErrorInfo for extra information.""" + return cls(reason="EXTRA_INFO", domain="evalprotocol.io", metadata=metadata) + + @classmethod + def rollout_error(cls, metadata: Dict[str, Any]) 
-> "ErrorInfo": + """Create an ErrorInfo for rollout errors.""" + return cls(reason="ROLLOUT_ERROR", domain="evalprotocol.io", metadata=metadata) + + @classmethod + def stopped_reason(cls, reason: str) -> "ErrorInfo": + """Create an ErrorInfo for stopped reason.""" + return cls(reason="STOPPED", domain="evalprotocol.io", metadata={"reason": reason}) + + +class Status(BaseModel): + """ + AIP-193 compatible Status model for standardized error responses. + + This model follows Google's AIP-193 standard for error handling: + https://google.aip.dev/193 + + Attributes: + code (int): The status code, must be the numeric value of one of the elements + of google.rpc.Code enum (e.g., 5 for NOT_FOUND). + message (str): Developer-facing, human-readable debug message in English. + details (List[Dict[str, Any]]): Additional error information, each packed in + a google.protobuf.Any message format. + """ + + code: "Status.Code" = Field(..., description="The status code from google.rpc.Code enum") + message: str = Field(..., description="Developer-facing, human-readable debug message in English") + details: List[Dict[str, Any]] = Field( + default_factory=list, + description="Additional error information, each packed in a google.protobuf.Any message format", + ) + + # Convenience constants for common status codes + class Code(int, Enum): + """Common gRPC status codes as defined in google.rpc.Code""" + + OK = 0 + CANCELLED = 1 + UNKNOWN = 2 + INVALID_ARGUMENT = 3 + DEADLINE_EXCEEDED = 4 + NOT_FOUND = 5 + ALREADY_EXISTS = 6 + PERMISSION_DENIED = 7 + RESOURCE_EXHAUSTED = 8 + FAILED_PRECONDITION = 9 + ABORTED = 10 + OUT_OF_RANGE = 11 + UNIMPLEMENTED = 12 + INTERNAL = 13 + UNAVAILABLE = 14 + DATA_LOSS = 15 + UNAUTHENTICATED = 16 + + # Custom codes for rollout states (using higher numbers to avoid conflicts) + FINISHED = 100 # Custom code for rollout finished + + @classmethod + def rollout_running(cls) -> "Status": + """Create a status indicating the rollout is running.""" + return 
cls(code=cls.Code.OK, message="Rollout is running", details=[]) + + @classmethod + def rollout_finished( + cls, termination_reason: Optional[str] = None, extra_info: Optional[Dict[str, Any]] = None + ) -> "Status": + """Create a status indicating the rollout finished.""" + details = [] + if termination_reason: + details.append(ErrorInfo.termination_reason(termination_reason).to_aip193_format()) + if extra_info: + details.append(ErrorInfo.extra_info(extra_info).to_aip193_format()) + return cls(code=cls.Code.FINISHED, message="Rollout finished", details=details) + + @classmethod + def rollout_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": + """Create a status indicating the rollout failed with an error.""" + details = [] + if extra_info: + details.append(ErrorInfo.rollout_error(extra_info).to_aip193_format()) + return cls.error(error_message, details) + + @classmethod + def error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating the rollout failed with an error.""" + return cls(code=cls.Code.INTERNAL, message=error_message, details=details) + + @classmethod + def rollout_stopped(cls, reason: str = "Rollout stopped") -> "Status": + """Create a status indicating the rollout was stopped.""" + details = [ErrorInfo.stopped_reason(reason).to_aip193_format()] + return cls(code=cls.Code.CANCELLED, message=reason, details=details) + + @classmethod + def with_termination_reason(cls, termination_reason: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": + """Create a status indicating the rollout finished with termination reason.""" + details = [ErrorInfo.termination_reason(termination_reason).to_aip193_format()] + + if extra_info: + details.append(ErrorInfo.extra_info(extra_info).to_aip193_format()) + + return cls(code=cls.Code.FINISHED, message="Rollout finished", details=details) + + def is_running(self) -> bool: + """Check if the status indicates the 
rollout is running.""" + return self.code == self.Code.OK and self.message == "Rollout is running" + + def is_finished(self) -> bool: + """Check if the status indicates the rollout finished successfully.""" + return self.code == self.Code.FINISHED + + def is_error(self) -> bool: + """Check if the status indicates the rollout failed with an error.""" + return self.code == self.Code.INTERNAL + + def is_stopped(self) -> bool: + """Check if the status indicates the rollout was stopped.""" + return self.code == self.Code.CANCELLED + + def get_termination_reason(self) -> Optional[str]: + """Extract termination reason from details if present.""" + for detail in self.details: + if detail.get("@type") == "type.googleapis.com/google.rpc.ErrorInfo": + metadata = detail.get("metadata", {}) + if detail.get("reason") == "TERMINATION_REASON" and "termination_reason" in metadata: + return metadata["termination_reason"] + return None + + def get_extra_info(self) -> Optional[Dict[str, Any]]: + """Extract extra info from details if present.""" + for detail in self.details: + if detail.get("@type") == "type.googleapis.com/google.rpc.ErrorInfo": + metadata = detail.get("metadata", {}) + reason = detail.get("reason") + # Skip termination_reason and stopped details, return other error info + if reason not in ["TERMINATION_REASON", "STOPPED"]: + return metadata + return None + + def __hash__(self) -> int: + """Generate a hash for the Status object.""" + # Use a stable hash based on code, message, and details + import hashlib + + # Create a stable string representation + hash_data = f"{self.code}:{self.message}:{len(self.details)}" + + # Add details content for more uniqueness + for detail in sorted(self.details, key=lambda x: str(x)): + hash_data += f":{str(detail)}" + + # Generate hash + hash_obj = hashlib.sha256(hash_data.encode("utf-8")) + return int.from_bytes(hash_obj.digest()[:8], byteorder="big") + + class ChatCompletionContentPartTextParam(BaseModel): text: str = Field(..., 
description="The text content.") type: Literal["text"] = Field("text", description="The type of the content part.") @@ -289,27 +482,6 @@ class ExecutionMetadata(BaseModel): ) -class RolloutStatus(BaseModel): - """Status of the rollout.""" - - """ - running: Unfinished rollout which is still in progress. - finished: Rollout finished. - error: Rollout failed due to unexpected error. The rollout record should be discard. - """ - - class Status(str, Enum): - RUNNING = "running" - FINISHED = "finished" - ERROR = "error" - - status: Status = Field(Status.RUNNING, description="Status of the rollout.") - termination_reason: Optional[TerminationReason] = Field( - None, description="reason of the rollout status, mapped to values in TerminationReason" - ) - extra_info: Optional[Dict[str, Any]] = Field(None, description="Extra information about the rollout status.") - - class EvaluationRow(BaseModel): """ Unified data structure for a single evaluation unit that contains messages, @@ -334,9 +506,9 @@ class EvaluationRow(BaseModel): description="Metadata related to the input (dataset info, model config, session data, etc.).", ) - rollout_status: RolloutStatus = Field( - default_factory=RolloutStatus, - description="The status of the rollout.", + rollout_status: Status = Field( + default_factory=Status.rollout_running, + description="The status of the rollout following AIP-193 standards.", ) # Ground truth reference (moved from EvaluateResult to top level) @@ -381,6 +553,14 @@ def is_trajectory_evaluation(self) -> bool: and len(self.evaluation_result.step_outputs) > 0 ) + def get_rollout_status(self) -> Status: + """Get the rollout status (backwards compatibility method).""" + return self.rollout_status + + def set_rollout_status(self, status: Status) -> None: + """Set the rollout status (backwards compatibility method).""" + self.rollout_status = status + def get_conversation_length(self) -> int: """Returns the number of messages in the conversation.""" return len(self.messages) 
diff --git a/eval_protocol/pytest/plugin.py b/eval_protocol/pytest/plugin.py index 81b36420..460eeb14 100644 --- a/eval_protocol/pytest/plugin.py +++ b/eval_protocol/pytest/plugin.py @@ -64,7 +64,7 @@ def pytest_addoption(parser) -> None: action="store", type=int, default=0, - help=("Failed rollouts (with rollout_status.status == 'error') will be retried up to this many times."), + help=("Failed rollouts (with rollout_status.code indicating error) will be retried up to this many times."), ) group.addoption( "--ep-fail-on-max-retry", diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index e0b8328a..96ac2add 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Union from eval_protocol.dataset_logger.dataset_logger import DatasetLogger -from eval_protocol.models import EvalMetadata, EvaluationRow, RolloutStatus +from eval_protocol.models import EvalMetadata, EvaluationRow, Status from eval_protocol.pytest.rollout_processor import RolloutProcessor from eval_protocol.pytest.types import ( CompletionParams, @@ -257,7 +257,7 @@ async def retry_handler(failed_row: EvaluationRow): current_attempts = retry_counts.get(rollout_id, 0) if current_attempts >= max_retry: - assert failed_row.rollout_status and failed_row.rollout_status.status == RolloutStatus.Status.ERROR, ( + assert failed_row.rollout_status and failed_row.rollout_status.is_error(), ( f"Rollout {failed_row.execution_metadata.rollout_id} did not fail with error status" ) failed_permanently.append(failed_row) @@ -273,11 +273,10 @@ async def retry_handler(failed_row: EvaluationRow): try: retry_result = await retry_tasks[0] - retry_result.rollout_status.status = RolloutStatus.Status.FINISHED + retry_result.rollout_status = Status.rollout_finished() await queue.put(retry_result) except Exception as e: - failed_row.rollout_status.status = RolloutStatus.Status.ERROR - 
failed_row.rollout_status.termination_reason = str(e) + failed_row.rollout_status = Status.rollout_error(str(e)) asyncio.create_task(retry_handler(failed_row)) # retry failed, spawn another retry async def initial_processor(): @@ -299,12 +298,11 @@ async def initial_processor(): try: result = await task - result.rollout_status.status = RolloutStatus.Status.FINISHED + result.rollout_status = Status.rollout_finished() await queue.put(result) except Exception as e: failed_row = fresh_dataset[task_index] - failed_row.rollout_status.status = RolloutStatus.Status.ERROR - failed_row.rollout_status.termination_reason = str(e) + failed_row.rollout_status = Status.rollout_error(str(e)) asyncio.create_task(retry_handler(failed_row)) # rollout errored, spawn retry task processor_task = asyncio.create_task(initial_processor()) @@ -317,10 +315,11 @@ async def initial_processor(): finished_row = await queue.get() # only permanent failure rows are put on the queue, so we can check for them here - if finished_row.rollout_status and finished_row.rollout_status.status == RolloutStatus.Status.ERROR: + if finished_row.rollout_status and finished_row.rollout_status.is_error(): if max_retry > 0 and os.getenv("EP_FAIL_ON_MAX_RETRY", "true") != "false": + termination_reason = finished_row.rollout_status.get_termination_reason() or "Unknown error" raise RuntimeError( - f"Rollout {finished_row.execution_metadata.rollout_id} failed after {max_retry} retries. Errors: {finished_row.rollout_status.termination_reason}" + f"Rollout {finished_row.execution_metadata.rollout_id} failed after {max_retry} retries. Errors: {termination_reason}" ) completed_count += 1 diff --git a/tests/test_migration_changes.py b/tests/test_migration_changes.py new file mode 100644 index 00000000..fed33e0b --- /dev/null +++ b/tests/test_migration_changes.py @@ -0,0 +1,446 @@ +#!/usr/bin/env python3 +""" +Tests for the migration changes in the existing codebase. 
+ +This test suite verifies that: +- All migrated code works correctly with the new Status model +- The field name remains as 'rollout_status' +- All helper methods work as expected +- AIP-193 compliance is maintained +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock +from eval_protocol.models import Status, EvaluationRow, Message +from eval_protocol.types import TerminationReason + + +class TestMCPExecutionManagerMigration: + """Test the migration changes in MCP execution manager.""" + + def test_trajectory_terminated_status_creation(self): + """Test that terminated trajectory creates correct status.""" + # Mock trajectory with termination + trajectory = Mock() + trajectory.terminated = True + trajectory.termination_reason = "goal_reached" + trajectory.control_plane_summary = {"error_message": "No errors"} + + # Create evaluation row + row = EvaluationRow(messages=[]) + + # Simulate the status assignment from MCP execution manager + extra_info = {"error_message": trajectory.control_plane_summary.get("error_message")} + + row.rollout_status = Status( + code=Status.Code.FINISHED, + message="Rollout finished", + details=[ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "TERMINATION_REASON", + "domain": "evalprotocol.io", + "metadata": {"termination_reason": trajectory.termination_reason}, + } + ] + + ( + [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "EXTRA_INFO", + "domain": "evalprotocol.io", + "metadata": extra_info, + } + ] + if extra_info + else [] + ), + ) + + # Verify the status + assert row.rollout_status.code == Status.Code.FINISHED + assert row.rollout_status.message == "Rollout finished" + assert row.rollout_status.is_finished() + + # Verify termination reason + assert row.rollout_status.get_termination_reason() == "goal_reached" + + # Verify extra info + assert row.rollout_status.get_extra_info() == {"error_message": "No errors"} + + # Verify details structure + assert 
len(row.rollout_status.details) == 2 + assert row.rollout_status.details[0]["reason"] == "TERMINATION_REASON" + assert row.rollout_status.details[1]["reason"] == "EXTRA_INFO" + + def test_trajectory_running_status_creation(self): + """Test that running trajectory creates correct status.""" + # Mock trajectory that's still running + trajectory = Mock() + trajectory.terminated = False + + # Create evaluation row + row = EvaluationRow(messages=[]) + + # Simulate the status assignment from MCP execution manager + row.rollout_status = Status(code=Status.Code.OK, message="Rollout is running", details=[]) + + # Verify the status + assert row.rollout_status.code == Status.Code.OK + assert row.rollout_status.message == "Rollout is running" + assert row.rollout_status.is_running() + assert not row.rollout_status.is_finished() + assert not row.rollout_status.is_error() + assert not row.rollout_status.is_stopped() + + def test_trajectory_terminated_without_error_message(self): + """Test terminated trajectory without error message.""" + # Mock trajectory with termination but no error + trajectory = Mock() + trajectory.terminated = True + trajectory.termination_reason = "timeout" + trajectory.control_plane_summary = {} + + # Create evaluation row + row = EvaluationRow(messages=[]) + + # Simulate the status assignment + extra_info = None + if trajectory.control_plane_summary.get("error_message"): + extra_info = {"error_message": trajectory.control_plane_summary.get("error_message")} + + row.rollout_status = Status( + code=Status.Code.FINISHED, + message="Rollout finished", + details=[ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "TERMINATION_REASON", + "domain": "evalprotocol.io", + "metadata": {"termination_reason": trajectory.termination_reason}, + } + ] + + ( + [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "EXTRA_INFO", + "domain": "evalprotocol.io", + "metadata": extra_info, + } + ] + if extra_info + else [] + ), + ) + + # 
Verify the status + assert row.rollout_status.code == Status.Code.FINISHED + assert row.rollout_status.is_finished() + assert row.rollout_status.get_termination_reason() == "timeout" + + # Should not have extra info since there was no error message + assert row.rollout_status.get_extra_info() is None + + # Should only have termination reason detail + assert len(row.rollout_status.details) == 1 + assert row.rollout_status.details[0]["reason"] == "TERMINATION_REASON" + + +class TestPytestUtilsMigration: + """Test the migration changes in pytest utils.""" + + def test_retry_success_status_update(self): + """Test that retry success updates status correctly.""" + row = EvaluationRow(messages=[]) + + # Simulate the status update from pytest utils + row.rollout_status = Status(code=Status.Code.FINISHED, message="Rollout finished successfully", details=[]) + + # Verify the status + assert row.rollout_status.code == Status.Code.FINISHED + assert row.rollout_status.message == "Rollout finished successfully" + assert row.rollout_status.is_finished() + assert not row.rollout_status.is_running() + + def test_retry_failure_status_update(self): + """Test that retry failure updates status correctly.""" + row = EvaluationRow(messages=[]) + + # Simulate the status update from pytest utils + row.rollout_status = Status( + code=Status.Code.INTERNAL, + message="Test error message", + details=[ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "ROLLOUT_ERROR", + "domain": "evalprotocol.io", + "metadata": {}, + } + ], + ) + + # Verify the status + assert row.rollout_status.code == Status.Code.INTERNAL + assert row.rollout_status.message == "Test error message" + assert row.rollout_status.is_error() + assert not row.rollout_status.is_finished() + + def test_initial_processor_success_status_update(self): + """Test that initial processor success updates status correctly.""" + row = EvaluationRow(messages=[]) + + # Simulate the status update from pytest utils + 
row.rollout_status = Status(code=Status.Code.FINISHED, message="Rollout finished successfully", details=[]) + + # Verify the status + assert row.rollout_status.code == Status.Code.FINISHED + assert row.rollout_status.message == "Rollout finished successfully" + assert row.rollout_status.is_finished() + + def test_initial_processor_failure_status_update(self): + """Test that initial processor failure updates status correctly.""" + row = EvaluationRow(messages=[]) + + # Simulate the status update from pytest utils + row.rollout_status = Status( + code=Status.Code.INTERNAL, + message="Runtime error occurred", + details=[ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "ROLLOUT_ERROR", + "domain": "evalprotocol.io", + "metadata": {}, + } + ], + ) + + # Verify the status + assert row.rollout_status.code == Status.Code.INTERNAL + assert row.rollout_status.message == "Runtime error occurred" + assert row.rollout_status.is_error() + + def test_error_status_checking(self): + """Test that error status checking works correctly.""" + row = EvaluationRow(messages=[]) + + # Set error status + row.rollout_status = Status.rollout_error("Test error") + + # Should be detected as error + assert row.rollout_status.is_error() + + # Should be able to get termination reason (None for error status) + assert row.rollout_status.get_termination_reason() is None + + +class TestTestRetryMechanismMigration: + """Test the migration changes in test_retry_mechanism.py.""" + + def test_success_failure_detection(self): + """Test that success/failure detection works correctly.""" + row = EvaluationRow(messages=[]) + + # Test success case + row.rollout_status = Status.rollout_finished() + success = row.rollout_status.is_finished() + assert success is True + + # Test failure case + row.rollout_status = Status.rollout_error("Test error") + success = row.rollout_status.is_finished() + assert success is False + + def test_score_assignment_based_on_status(self): + """Test that score 
assignment works based on status.""" + row = EvaluationRow(messages=[]) + + # Test success score + row.rollout_status = Status.rollout_finished() + score = 1.0 if row.rollout_status.is_finished() else 0.0 + assert score == 1.0 + + # Test failure score + row.rollout_status = Status.rollout_error("Test error") + score = 1.0 if row.rollout_status.is_finished() else 0.0 + assert score == 0.0 + + +class TestStatusModelIntegration: + """Test integration of Status model with existing functionality.""" + + def test_status_creation_methods_integration(self): + """Test that all status creation methods work together.""" + row = EvaluationRow(messages=[]) + + # Test running status + row.rollout_status = Status.rollout_running() + assert row.rollout_status.is_running() + assert row.rollout_status.code == Status.Code.OK + + # Test finished status + row.rollout_status = Status.rollout_finished() + assert row.rollout_status.is_finished() + assert row.rollout_status.code == Status.Code.FINISHED + + # Test error status + row.rollout_status = Status.rollout_error("Test error") + assert row.rollout_status.is_error() + assert row.rollout_status.code == Status.Code.INTERNAL + + # Test stopped status + row.rollout_status = Status.rollout_stopped("User stop") + assert row.rollout_status.is_stopped() + assert row.rollout_status.code == Status.Code.CANCELLED + + def test_termination_reason_integration(self): + """Test integration of termination reason with status.""" + row = EvaluationRow(messages=[]) + + # Test with termination reason + termination_status = Status.with_termination_reason("goal_reached") + row.rollout_status = termination_status + + assert row.rollout_status.is_finished() + assert row.rollout_status.get_termination_reason() == "goal_reached" + + # Test with termination reason and extra info + extra_info = {"steps": 10, "reward": 0.8} + termination_status_with_info = Status.with_termination_reason("timeout", extra_info) + row.rollout_status = termination_status_with_info + + 
assert row.rollout_status.is_finished() + assert row.rollout_status.get_termination_reason() == "timeout" + assert row.rollout_status.get_extra_info() == extra_info + + def test_error_handling_integration(self): + """Test integration of error handling with status.""" + row = EvaluationRow(messages=[]) + + # Test error with metadata + error_info = {"error_code": "E001", "line": 42} + error_status = Status.rollout_error("Runtime error", error_info) + row.rollout_status = error_status + + assert row.rollout_status.is_error() + assert row.rollout_status.get_extra_info() == error_info + assert row.rollout_status.get_termination_reason() is None + + # Test error without metadata + simple_error_status = Status.rollout_error("Simple error") + row.rollout_status = simple_error_status + + assert row.rollout_status.is_error() + assert row.rollout_status.get_extra_info() is None + + def test_status_transitions_integration(self): + """Test that status transitions work correctly in integration.""" + row = EvaluationRow(messages=[]) + + # Start with running + row.rollout_status = Status.rollout_running() + assert row.rollout_status.is_running() + + # Transition to finished + row.rollout_status = Status.rollout_finished() + assert row.rollout_status.is_finished() + assert not row.rollout_status.is_running() + + # Transition to error + row.rollout_status = Status.rollout_error("Something went wrong") + assert row.rollout_status.is_error() + assert not row.rollout_status.is_finished() + + # Transition back to finished + row.rollout_status = Status.rollout_finished() + assert row.rollout_status.is_finished() + assert not row.rollout_status.is_error() + + +class TestAIP193Compliance: + """Test AIP-193 compliance in the migrated code.""" + + def test_error_info_structure_compliance(self): + """Test that ErrorInfo structure follows AIP-193.""" + row = EvaluationRow(messages=[]) + + # Create error status with metadata + error_info = {"error_code": "E001", "timestamp": "2024-01-01"} + 
error_status = Status.rollout_error("Test error", error_info) + row.rollout_status = error_status + + # Check AIP-193 ErrorInfo structure + assert len(row.rollout_status.details) == 1 + detail = row.rollout_status.details[0] + + # Required fields according to AIP-193 + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert "reason" in detail + assert "domain" in detail + assert "metadata" in detail + + # Domain should be service-specific + assert detail["domain"] == "evalprotocol.io" + + # Metadata should contain the error info + assert detail["metadata"] == error_info + + def test_termination_reason_structure_compliance(self): + """Test that termination reason structure follows AIP-193.""" + row = EvaluationRow(messages=[]) + + # Create status with termination reason + termination_status = Status.with_termination_reason("goal_reached") + row.rollout_status = termination_status + + # Check AIP-193 structure + assert len(row.rollout_status.details) == 1 + detail = row.rollout_status.details[0] + + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert detail["reason"] == "TERMINATION_REASON" + assert detail["domain"] == "evalprotocol.io" + assert "metadata" in detail + assert "termination_reason" in detail["metadata"] + + def test_multiple_details_compliance(self): + """Test that multiple details follow AIP-193 structure.""" + row = EvaluationRow(messages=[]) + + # Create status with both termination reason and extra info + extra_info = {"steps": 15, "reward": 0.9} + status = Status.with_termination_reason("goal_reached", extra_info) + row.rollout_status = status + + # Should have two details + assert len(row.rollout_status.details) == 2 + + # Both should follow ErrorInfo structure + for detail in row.rollout_status.details: + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert "reason" in detail + assert "domain" in detail + assert "metadata" in detail + assert detail["domain"] == 
"evalprotocol.io" + + def test_status_code_compliance(self): + """Test that status codes follow gRPC standard.""" + row = EvaluationRow(messages=[]) + + # Test standard gRPC codes + statuses = [ + (Status.rollout_running(), Status.Code.OK), + (Status.rollout_finished(), Status.Code.FINISHED), # Custom code + (Status.rollout_error("Test"), Status.Code.INTERNAL), + (Status.rollout_stopped("Test"), Status.Code.CANCELLED), + ] + + for status, expected_code in statuses: + row.rollout_status = status + assert row.rollout_status.code == expected_code + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py index 8b55869f..43681903 100644 --- a/tests/test_retry_mechanism.py +++ b/tests/test_retry_mechanism.py @@ -11,7 +11,7 @@ import pytest -from eval_protocol.models import EvaluateResult, EvaluationRow, Message, RolloutStatus +from eval_protocol.models import EvaluateResult, EvaluationRow, Message, Status from eval_protocol.pytest.evaluation_test import evaluation_test from eval_protocol.pytest.rollout_processor import RolloutProcessor from eval_protocol.pytest.types import RolloutProcessorConfig @@ -93,11 +93,11 @@ async def process_single_row( def test_retry_mechanism(row: EvaluationRow) -> EvaluationRow: """MOCK TEST: Tests that retry mechanism works - one task fails on first attempt, succeeds on retry.""" print( - f"šŸ“Š EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})" + f"šŸ“Š EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})" ) # Assign a score based on success/failure - score = 1.0 if row.rollout_status.status == "finished" else 0.0 + score = 1.0 if row.rollout_status.is_finished() else 0.0 row.evaluation_result = EvaluateResult(score=score) return row diff --git a/tests/test_status_migration_integration.py b/tests/test_status_migration_integration.py new file 
mode 100644 index 00000000..12291fdd --- /dev/null +++ b/tests/test_status_migration_integration.py @@ -0,0 +1,430 @@ +#!/usr/bin/env python3 +""" +Integration tests for the Status model migration. + +This test suite verifies that: +- All migrated code works correctly with the new Status model +- The field name remains as 'rollout_status' +- All helper methods work as expected +- AIP-193 compliance is maintained +""" + +import pytest +from unittest.mock import Mock, patch +from eval_protocol.models import Status, EvaluationRow, Message +from eval_protocol.types import TerminationReason + + +class TestStatusFieldNamePreservation: + """Test that the field name remains as 'rollout_status'.""" + + def test_evaluation_row_has_rollout_status_field(self): + """Test that EvaluationRow still has rollout_status field.""" + row = EvaluationRow(messages=[]) + + # Should have rollout_status field + assert hasattr(row, "rollout_status") + assert not hasattr(row, "status") + + # Field should be of type Status + assert isinstance(row.rollout_status, Status) + + def test_rollout_status_field_access(self): + """Test direct access to rollout_status field.""" + row = EvaluationRow(messages=[]) + + # Should be able to access directly + assert row.rollout_status.code == Status.Code.OK + assert row.rollout_status.message == "Rollout is running" + + # Should be able to set directly + row.rollout_status = Status.rollout_finished() + assert row.rollout_status.code == Status.Code.FINISHED + + +class TestBackwardsCompatibilityMethods: + """Test the backwards compatibility methods.""" + + def test_get_rollout_status_method(self): + """Test the get_rollout_status method.""" + row = EvaluationRow(messages=[]) + + # Method should return the current rollout_status + status = row.get_rollout_status() + assert status.code == Status.Code.OK + assert status.message == "Rollout is running" + + # Should be the same object reference + assert status is row.rollout_status + + def 
test_set_rollout_status_method(self): + """Test the set_rollout_status method.""" + row = EvaluationRow(messages=[]) + + # Method should update the rollout_status + new_status = Status.rollout_error("Test error") + row.set_rollout_status(new_status) + + assert row.rollout_status.code == Status.Code.INTERNAL + assert row.rollout_status.message == "Test error" + assert row.rollout_status is new_status + + +class TestStatusTransitions: + """Test transitioning between different status states.""" + + def test_running_to_finished_transition(self): + """Test transition from running to finished.""" + row = EvaluationRow(messages=[]) + + # Start with running + assert row.rollout_status.is_running() + assert not row.rollout_status.is_finished() + + # Transition to finished + row.rollout_status = Status.rollout_finished() + assert not row.rollout_status.is_running() + assert row.rollout_status.is_finished() + + def test_running_to_error_transition(self): + """Test transition from running to error.""" + row = EvaluationRow(messages=[]) + + # Start with running + assert row.rollout_status.is_running() + assert not row.rollout_status.is_error() + + # Transition to error + row.rollout_status = Status.rollout_error("Something went wrong") + assert not row.rollout_status.is_running() + assert row.rollout_status.is_error() + + def test_running_to_stopped_transition(self): + """Test transition from running to stopped.""" + row = EvaluationRow(messages=[]) + + # Start with running + assert row.rollout_status.is_running() + assert not row.rollout_status.is_stopped() + + # Transition to stopped + row.rollout_status = Status.rollout_stopped("User requested stop") + assert not row.rollout_status.is_running() + assert row.rollout_status.is_stopped() + + def test_error_to_finished_transition(self): + """Test transition from error to finished.""" + row = EvaluationRow(messages=[]) + + # Start with error + row.rollout_status = Status.rollout_error("Initial error") + assert 
row.rollout_status.is_error() + + # Transition to finished + row.rollout_status = Status.rollout_finished() + assert not row.rollout_status.is_error() + assert row.rollout_status.is_finished() + + +class TestTerminationReasonIntegration: + """Test integration of termination reason with the new Status model.""" + + def test_termination_reason_in_status_details(self): + """Test that termination reason is properly stored in status details.""" + row = EvaluationRow(messages=[]) + + # Set status with termination reason + termination_status = Status.with_termination_reason("goal_reached") + row.rollout_status = termination_status + + # Should be finished + assert row.rollout_status.is_finished() + + # Should have termination reason in details + assert row.rollout_status.get_termination_reason() == "goal_reached" + + # Check details structure + assert len(row.rollout_status.details) == 1 + detail = row.rollout_status.details[0] + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert detail["reason"] == "TERMINATION_REASON" + assert detail["domain"] == "evalprotocol.io" + assert detail["metadata"]["termination_reason"] == "goal_reached" + + def test_termination_reason_with_extra_info(self): + """Test termination reason with additional extra info.""" + row = EvaluationRow(messages=[]) + + extra_info = {"steps": 10, "reward": 0.8} + termination_status = Status.with_termination_reason("timeout", extra_info) + row.rollout_status = termination_status + + # Should have both termination reason and extra info + assert row.rollout_status.get_termination_reason() == "timeout" + assert row.rollout_status.get_extra_info() == extra_info + + # Check details structure + assert len(row.rollout_status.details) == 2 + + # First detail should be termination reason + term_detail = row.rollout_status.details[0] + assert term_detail["reason"] == "TERMINATION_REASON" + + # Second detail should be extra info + extra_detail = row.rollout_status.details[1] + assert 
extra_detail["reason"] == "EXTRA_INFO" + assert extra_detail["metadata"] == extra_info + + def test_multiple_termination_reasons(self): + """Test handling of multiple termination reasons (edge case).""" + row = EvaluationRow(messages=[]) + + # Create status with duplicate termination reason details + details = [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "TERMINATION_REASON", + "domain": "evalprotocol.io", + "metadata": {"termination_reason": "first"}, + }, + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "TERMINATION_REASON", + "domain": "evalprotocol.io", + "metadata": {"termination_reason": "second"}, + }, + ] + + status = Status(code=Status.Code.FINISHED, message="Test", details=details) + row.rollout_status = status + + # Should return the first termination reason found + assert row.rollout_status.get_termination_reason() == "first" + + +class TestErrorHandlingIntegration: + """Test error handling integration with the new Status model.""" + + def test_error_status_with_metadata(self): + """Test error status with structured metadata.""" + row = EvaluationRow(messages=[]) + + error_info = { + "error_code": "E001", + "line": 42, + "function": "test_function", + "timestamp": "2024-01-01T12:00:00Z", + } + + error_status = Status.rollout_error("Runtime error occurred", error_info) + row.rollout_status = error_status + + # Should be error + assert row.rollout_status.is_error() + + # Should have error details + assert row.rollout_status.get_extra_info() == error_info + + # Should not have termination reason + assert row.rollout_status.get_termination_reason() is None + + # Check details structure + assert len(row.rollout_status.details) == 1 + detail = row.rollout_status.details[0] + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert detail["reason"] == "ROLLOUT_ERROR" + assert detail["domain"] == "evalprotocol.io" + assert detail["metadata"] == error_info + + def 
test_error_status_without_metadata(self): + """Test error status without additional metadata.""" + row = EvaluationRow(messages=[]) + + error_status = Status.rollout_error("Simple error message") + row.rollout_status = error_status + + # Should be error + assert row.rollout_status.is_error() + + # Should not have extra info + assert row.rollout_status.get_extra_info() is None + + # Should not have termination reason + assert row.rollout_status.get_termination_reason() is None + + # Should have empty details + assert row.rollout_status.details == [] + + +class TestAIP193Compliance: + """Test AIP-193 compliance features.""" + + def test_error_info_structure(self): + """Test that ErrorInfo follows AIP-193 structure.""" + row = EvaluationRow(messages=[]) + + # Create status with error info + error_info = {"error_code": "E001"} + error_status = Status.rollout_error("Test error", error_info) + row.rollout_status = error_status + + # Check AIP-193 ErrorInfo structure + assert len(row.rollout_status.details) == 1 + detail = row.rollout_status.details[0] + + # Required fields + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert "reason" in detail + assert "domain" in detail + assert "metadata" in detail + + # Domain should be service-specific + assert detail["domain"] == "evalprotocol.io" + + # Metadata should contain the error info + assert detail["metadata"] == error_info + + def test_multiple_detail_types(self): + """Test that multiple detail types can coexist.""" + row = EvaluationRow(messages=[]) + + # Create status with both termination reason and extra info + extra_info = {"steps": 15, "reward": 0.9} + status = Status.with_termination_reason("goal_reached", extra_info) + row.rollout_status = status + + # Should have two details + assert len(row.rollout_status.details) == 2 + + # First detail should be termination reason + term_detail = row.rollout_status.details[0] + assert term_detail["reason"] == "TERMINATION_REASON" + + # Second detail 
should be extra info + extra_detail = row.rollout_status.details[1] + assert extra_detail["reason"] == "EXTRA_INFO" + + # Both should follow ErrorInfo structure + for detail in row.rollout_status.details: + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert "reason" in detail + assert "domain" in detail + assert "metadata" in detail + + def test_status_code_mapping(self): + """Test that status codes map correctly to gRPC codes.""" + row = EvaluationRow(messages=[]) + + # Test different status types and their codes + statuses = [ + (Status.rollout_running(), Status.Code.OK), + (Status.rollout_finished(), Status.Code.FINISHED), + (Status.rollout_error("Test"), Status.Code.INTERNAL), + (Status.rollout_stopped("Test"), Status.Code.CANCELLED), + ] + + for status, expected_code in statuses: + row.rollout_status = status + assert row.rollout_status.code == expected_code + + +class TestSerializationAndDeserialization: + """Test that Status can be properly serialized and deserialized.""" + + def test_status_model_dump(self): + """Test that Status can be dumped to dict.""" + row = EvaluationRow(messages=[]) + + # Set a complex status + extra_info = {"steps": 10, "reward": 0.8} + termination_status = Status.with_termination_reason("goal_reached", extra_info) + row.rollout_status = termination_status + + # Dump to dict + status_dict = row.rollout_status.model_dump() + + # Check structure + assert "code" in status_dict + assert "message" in status_dict + assert "details" in status_dict + + # Check values + assert status_dict["code"] == Status.Code.FINISHED + assert status_dict["message"] == "Rollout finished" + assert len(status_dict["details"]) == 2 + + def test_status_model_validate(self): + """Test that Status can be reconstructed from dict.""" + row = EvaluationRow(messages=[]) + + # Set a complex status + extra_info = {"steps": 10, "reward": 0.8} + original_status = Status.with_termination_reason("goal_reached", extra_info) + row.rollout_status = 
original_status + + # Dump and reconstruct + status_dict = row.rollout_status.model_dump() + reconstructed_status = Status.model_validate(status_dict) + + # Should be equivalent + assert reconstructed_status.code == original_status.code + assert reconstructed_status.message == original_status.message + assert len(reconstructed_status.details) == len(original_status.details) + + # Should preserve functionality + assert reconstructed_status.get_termination_reason() == "goal_reached" + assert reconstructed_status.get_extra_info() == extra_info + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_empty_status_details(self): + """Test Status with empty details.""" + row = EvaluationRow(messages=[]) + + # Create status with empty details + empty_status = Status(code=Status.Code.OK, message="Test", details=[]) + row.rollout_status = empty_status + + # Should handle gracefully + assert row.rollout_status.get_termination_reason() is None + assert row.rollout_status.get_extra_info() is None + + def test_malformed_status_details(self): + """Test Status with malformed details.""" + row = EvaluationRow(messages=[]) + + # Create status with malformed details + malformed_details = [ + {"not_type": "invalid", "reason": "TEST"}, + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"termination_reason": "test"}}, + ] + + malformed_status = Status(code=Status.Code.OK, message="Test", details=malformed_details) + row.rollout_status = malformed_status + + # Should handle gracefully + assert row.rollout_status.get_termination_reason() == "test" + assert row.rollout_status.get_extra_info() is None + + def test_large_metadata_handling(self): + """Test Status with large metadata.""" + row = EvaluationRow(messages=[]) + + # Create large metadata + large_metadata = {f"key_{i}": f"value_{i}" for i in range(100)} + + # Should handle large metadata + large_status = Status.rollout_error("Test error", large_metadata) + row.rollout_status = 
large_status + + # Should preserve all metadata + extra_info = row.rollout_status.get_extra_info() + assert extra_info == large_metadata + assert extra_info is not None + assert len(extra_info) == 100 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_status_model.py b/tests/test_status_model.py new file mode 100644 index 00000000..cdc5512b --- /dev/null +++ b/tests/test_status_model.py @@ -0,0 +1,423 @@ +#!/usr/bin/env python3 +""" +Tests for the AIP-193 compatible Status model. + +This test suite covers: +- Status code enum values +- Status creation methods +- Helper methods for checking status types +- AIP-193 compliance features +- Migration from RolloutStatus +""" + +import pytest +from eval_protocol.models import Status, EvaluationRow, Message, ErrorInfo + + +class TestErrorInfoModel: + """Test the ErrorInfo model.""" + + def test_error_info_creation(self): + """Test creating ErrorInfo instances.""" + error_info = ErrorInfo( + reason="TEST_ERROR", domain="evalprotocol.io", metadata={"error_code": "E001", "line": 42} + ) + + assert error_info.reason == "TEST_ERROR" + assert error_info.domain == "evalprotocol.io" + assert error_info.metadata == {"error_code": "E001", "line": 42} + + def test_error_info_to_aip193_format(self): + """Test conversion to AIP-193 format.""" + error_info = ErrorInfo(reason="TEST_ERROR", domain="evalprotocol.io", metadata={"error_code": "E001"}) + + aip193_format = error_info.to_aip193_format() + + assert aip193_format["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert aip193_format["reason"] == "TEST_ERROR" + assert aip193_format["domain"] == "evalprotocol.io" + assert aip193_format["metadata"] == {"error_code": "E001"} + + def test_error_info_factory_methods(self): + """Test the factory methods for common error types.""" + # Test termination_reason + term_error = ErrorInfo.termination_reason("goal_reached") + assert term_error.reason == "TERMINATION_REASON" + assert term_error.domain == 
"evalprotocol.io" + assert term_error.metadata["termination_reason"] == "goal_reached" + + # Test extra_info + extra_error = ErrorInfo.extra_info({"steps": 10, "reward": 0.8}) + assert extra_error.reason == "EXTRA_INFO" + assert extra_error.domain == "evalprotocol.io" + assert extra_error.metadata == {"steps": 10, "reward": 0.8} + + # Test rollout_error + rollout_error = ErrorInfo.rollout_error({"error_code": "E001"}) + assert rollout_error.reason == "ROLLOUT_ERROR" + assert rollout_error.domain == "evalprotocol.io" + assert rollout_error.metadata == {"error_code": "E001"} + + # Test stopped_reason + stopped_error = ErrorInfo.stopped_reason("user_request") + assert stopped_error.reason == "STOPPED" + assert stopped_error.domain == "evalprotocol.io" + assert stopped_error.metadata["reason"] == "user_request" + + +class TestStatusModel: + """Test the AIP-193 compatible Status model.""" + + def test_status_code_enum_values(self): + """Test that Status.Code enum has the correct values.""" + assert Status.Code.OK == 0 + assert Status.Code.CANCELLED == 1 + assert Status.Code.UNKNOWN == 2 + assert Status.Code.INVALID_ARGUMENT == 3 + assert Status.Code.DEADLINE_EXCEEDED == 4 + assert Status.Code.NOT_FOUND == 5 + assert Status.Code.ALREADY_EXISTS == 6 + assert Status.Code.PERMISSION_DENIED == 7 + assert Status.Code.RESOURCE_EXHAUSTED == 8 + assert Status.Code.FAILED_PRECONDITION == 9 + assert Status.Code.ABORTED == 10 + assert Status.Code.OUT_OF_RANGE == 11 + assert Status.Code.UNIMPLEMENTED == 12 + assert Status.Code.INTERNAL == 13 + assert Status.Code.UNAVAILABLE == 14 + assert Status.Code.DATA_LOSS == 15 + assert Status.Code.UNAUTHENTICATED == 16 + assert Status.Code.FINISHED == 100 # Custom code + + def test_status_creation_methods(self): + """Test the convenience methods for creating Status instances.""" + # Test running status + running_status = Status.rollout_running() + assert running_status.code == Status.Code.OK + assert running_status.message == "Rollout is 
running" + assert running_status.details == [] + + # Test finished status + finished_status = Status.rollout_finished() + assert finished_status.code == Status.Code.FINISHED + assert finished_status.message == "Rollout finished successfully" + assert finished_status.details == [] + + # Test error status + error_status = Status.rollout_error("Something went wrong") + assert error_status.code == Status.Code.INTERNAL + assert error_status.message == "Something went wrong" + assert error_status.details == [] + + # Test error status with extra info + extra_info = {"error_code": "E001", "timestamp": "2024-01-01"} + error_status_with_info = Status.rollout_error("Something went wrong", extra_info) + assert error_status_with_info.code == Status.Code.INTERNAL + assert error_status_with_info.message == "Something went wrong" + assert len(error_status_with_info.details) == 1 + assert error_status_with_info.details[0]["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert error_status_with_info.details[0]["reason"] == "ROLLOUT_ERROR" + assert error_status_with_info.details[0]["domain"] == "evalprotocol.io" + assert error_status_with_info.details[0]["metadata"] == extra_info + + # Test stopped status + stopped_status = Status.rollout_stopped("User requested stop") + assert stopped_status.code == Status.Code.CANCELLED + assert stopped_status.message == "User requested stop" + assert stopped_status.details == [] + + # Test with termination reason + termination_status = Status.with_termination_reason("goal_reached") + assert termination_status.code == Status.Code.FINISHED + assert termination_status.message == "Rollout finished" + assert len(termination_status.details) == 1 + assert termination_status.details[0]["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert termination_status.details[0]["reason"] == "TERMINATION_REASON" + assert termination_status.details[0]["domain"] == "evalprotocol.io" + assert 
termination_status.details[0]["metadata"]["termination_reason"] == "goal_reached" + + # Test with termination reason and extra info + extra_info = {"steps": 10, "reward": 0.8} + termination_status_with_info = Status.with_termination_reason("goal_reached", extra_info) + assert termination_status_with_info.code == Status.Code.FINISHED + assert len(termination_status_with_info.details) == 2 + # First detail should be termination reason + assert termination_status_with_info.details[0]["reason"] == "TERMINATION_REASON" + # Second detail should be extra info + assert termination_status_with_info.details[1]["reason"] == "EXTRA_INFO" + assert termination_status_with_info.details[1]["metadata"] == extra_info + + def test_status_helper_methods(self): + """Test the helper methods for checking status types.""" + # Test is_running + running_status = Status.rollout_running() + assert running_status.is_running() is True + assert running_status.is_finished() is False + assert running_status.is_error() is False + assert running_status.is_stopped() is False + + # Test is_finished + finished_status = Status.rollout_finished() + assert finished_status.is_running() is False + assert finished_status.is_finished() is True + assert finished_status.is_error() is False + assert finished_status.is_stopped() is False + + # Test is_error + error_status = Status.rollout_error("Test error") + assert error_status.is_running() is False + assert error_status.is_finished() is False + assert error_status.is_error() is True + assert error_status.is_stopped() is False + + # Test is_stopped + stopped_status = Status.rollout_stopped("Test stop") + assert stopped_status.is_running() is False + assert stopped_status.is_finished() is False + assert stopped_status.is_error() is False + assert stopped_status.is_stopped() is True + + def test_get_termination_reason(self): + """Test extracting termination reason from status details.""" + # Status without termination reason + running_status = 
Status.rollout_running() + assert running_status.get_termination_reason() is None + + # Status with termination reason + termination_status = Status.with_termination_reason("goal_reached") + assert termination_status.get_termination_reason() == "goal_reached" + + # Status with termination reason and extra info + extra_info = {"steps": 10} + termination_status_with_info = Status.with_termination_reason("timeout", extra_info) + assert termination_status_with_info.get_termination_reason() == "timeout" + + def test_get_extra_info(self): + """Test extracting extra info from status details.""" + # Status without extra info + running_status = Status.rollout_running() + assert running_status.get_extra_info() is None + + # Status with only termination reason (no extra info) + termination_status = Status.with_termination_reason("goal_reached") + assert termination_status.get_extra_info() is None + + # Status with extra info + extra_info = {"steps": 10, "reward": 0.8} + error_status = Status.rollout_error("Test error", extra_info) + assert error_status.get_extra_info() == extra_info + + # Status with both termination reason and extra info + termination_status_with_info = Status.with_termination_reason("goal_reached", extra_info) + assert termination_status_with_info.get_extra_info() == extra_info + + def test_aip_193_compliance(self): + """Test that Status model follows AIP-193 standards.""" + # Test ErrorInfo structure + extra_info = {"error_code": "E001"} + error_status = Status.rollout_error("Test error", extra_info) + + assert len(error_status.details) == 1 + detail = error_status.details[0] + + # Check AIP-193 ErrorInfo structure + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert detail["reason"] == "ROLLOUT_ERROR" + assert detail["domain"] == "evalprotocol.io" + assert detail["metadata"] == extra_info + + # Test multiple details + termination_status = Status.with_termination_reason("goal_reached", extra_info) + assert 
len(termination_status.details) == 2 + + # First detail should be termination reason + term_detail = termination_status.details[0] + assert term_detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert term_detail["reason"] == "TERMINATION_REASON" + + # Second detail should be extra info + extra_detail = termination_status.details[1] + assert extra_detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert extra_detail["reason"] == "EXTRA_INFO" + + def test_status_serialization(self): + """Test that Status can be serialized and deserialized.""" + original_status = Status.with_termination_reason("goal_reached", {"steps": 10}) + + # Test model_dump + status_dict = original_status.model_dump() + assert status_dict["code"] == Status.Code.FINISHED + assert status_dict["message"] == "Rollout finished" + assert len(status_dict["details"]) == 2 + + # Test model_validate + reconstructed_status = Status.model_validate(status_dict) + assert reconstructed_status.code == original_status.code + assert reconstructed_status.message == original_status.message + assert len(reconstructed_status.details) == len(original_status.details) + assert reconstructed_status.get_termination_reason() == "goal_reached" + assert reconstructed_status.get_extra_info() == {"steps": 10} + + def test_status_equality(self): + """Test Status equality and comparison.""" + status1 = Status.rollout_running() + status2 = Status.rollout_running() + status3 = Status.rollout_finished() + + # Same values should be equal + assert status1 == status2 + + # Different values should not be equal + assert status1 != status3 + + # Test hash + assert hash(status1) == hash(status2) + assert hash(status1) != hash(status3) + + +class TestStatusMigration: + """Test the migration from RolloutStatus to Status.""" + + def test_evaluation_row_default_status(self): + """Test that EvaluationRow has the correct default status.""" + row = EvaluationRow(messages=[]) + + # Should have rollout_status field 
(not status) + assert hasattr(row, "rollout_status") + assert not hasattr(row, "status") + + # Default status should be running + assert row.rollout_status.code == Status.Code.OK + assert row.rollout_status.message == "Rollout is running" + assert row.rollout_status.details == [] + + def test_backwards_compatibility_methods(self): + """Test the backwards compatibility methods.""" + row = EvaluationRow(messages=[]) + + # Test get_rollout_status + status = row.get_rollout_status() + assert status.code == Status.Code.OK + assert status.message == "Rollout is running" + + # Test set_rollout_status + new_status = Status.rollout_finished() + row.set_rollout_status(new_status) + assert row.rollout_status.code == Status.Code.FINISHED + assert row.rollout_status.message == "Rollout finished successfully" + + def test_status_transitions(self): + """Test transitioning between different status states.""" + row = EvaluationRow(messages=[]) + + # Start with running + assert row.rollout_status.is_running() + + # Transition to finished + row.rollout_status = Status.rollout_finished() + assert row.rollout_status.is_finished() + assert not row.rollout_status.is_running() + + # Transition to error + row.rollout_status = Status.rollout_error("Something went wrong") + assert row.rollout_status.is_error() + assert not row.rollout_status.is_finished() + + # Transition to stopped + row.rollout_status = Status.rollout_stopped("User requested stop") + assert row.rollout_status.is_stopped() + assert not row.rollout_status.is_error() + + def test_termination_reason_integration(self): + """Test integration of termination reason with status.""" + row = EvaluationRow(messages=[]) + + # Set status with termination reason + termination_status = Status.with_termination_reason("goal_reached", {"steps": 15}) + row.rollout_status = termination_status + + # Should be finished + assert row.rollout_status.is_finished() + + # Should have termination reason + assert 
row.rollout_status.get_termination_reason() == "goal_reached" + + # Should have extra info + extra_info = row.rollout_status.get_extra_info() + assert extra_info == {"steps": 15} + + def test_error_handling_integration(self): + """Test error handling integration with status.""" + row = EvaluationRow(messages=[]) + + # Set error status + error_info = {"error_code": "E001", "line": 42} + error_status = Status.rollout_error("Runtime error occurred", error_info) + row.rollout_status = error_status + + # Should be error + assert row.rollout_status.is_error() + + # Should have error details + assert row.rollout_status.get_extra_info() == error_info + + # Should not have termination reason + assert row.rollout_status.get_termination_reason() is None + + +class TestStatusEdgeCases: + """Test edge cases and error conditions.""" + + def test_empty_details(self): + """Test Status with empty details.""" + status = Status(code=Status.Code.OK, message="Test", details=[]) + assert status.details == [] + assert status.get_termination_reason() is None + assert status.get_extra_info() is None + + def test_malformed_details(self): + """Test Status with malformed details.""" + malformed_details = [ + {"not_type": "invalid", "reason": "TEST"}, + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"termination_reason": "test"}}, + ] + status = Status(code=Status.Code.OK, message="Test", details=malformed_details) + + # Should handle malformed details gracefully + assert status.get_termination_reason() == "test" + assert status.get_extra_info() is None + + def test_duplicate_detail_types(self): + """Test Status with duplicate detail types.""" + details = [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "TERMINATION_REASON", + "domain": "evalprotocol.io", + "metadata": {"termination_reason": "first"}, + }, + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "TERMINATION_REASON", + "domain": "evalprotocol.io", + "metadata": 
{"termination_reason": "second"}, + }, + ] + status = Status(code=Status.Code.OK, message="Test", details=details) + + # Should return the first termination reason found + assert status.get_termination_reason() == "first" + + def test_large_metadata(self): + """Test Status with large metadata.""" + large_metadata = {f"key_{i}": f"value_{i}" for i in range(100)} + status = Status.rollout_error("Test error", large_metadata) + + # Should handle large metadata + assert status.get_extra_info() == large_metadata + assert len(status.details) == 1 + assert status.details[0]["metadata"] == large_metadata + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/vite-app/src/types/eval-protocol.ts b/vite-app/src/types/eval-protocol.ts index e6932b40..05ba8b6c 100644 --- a/vite-app/src/types/eval-protocol.ts +++ b/vite-app/src/types/eval-protocol.ts @@ -81,13 +81,11 @@ export const EvalMetadataSchema = z.object({ passed: z.boolean().optional().describe('Whether the evaluation passed based on the threshold') }); -// Rollout status model (matches Python RolloutStatus) -export const RolloutStatusSchema = z.object({ - status: z - .enum(['running', 'finished', 'error', 'stopped']) - .default('finished') - .describe('Status of the rollout.'), - error_message: z.string().optional().describe('Error message if the rollout failed.') +// AIP-193 compatible Status model (matches Python Status) +export const StatusSchema = z.object({ + code: z.number().describe('The status code (numeric value from google.rpc.Code enum)'), + message: z.string().describe('Developer-facing, human-readable debug message in English'), + details: z.array(z.record(z.string(), z.any())).default([]).describe('Additional error information, each packed in a google.protobuf.Any message format') }); export const ExecutionMetadataSchema = z.object({ @@ -101,7 +99,7 @@ export const EvaluationRowSchema = z.object({ messages: z.array(MessageSchema).describe('List of messages in the conversation/trajectory.'), 
tools: z.array(z.record(z.string(), z.any())).optional().describe('Available tools/functions that were provided to the agent.'), input_metadata: InputMetadataSchema.describe('Metadata related to the input (dataset info, model config, session data, etc.).'), - rollout_status: RolloutStatusSchema.default({ status: 'finished' }).describe('The status of the rollout.'), + rollout_status: StatusSchema.default({ code: 0, message: 'Rollout is running', details: [] }).describe('The status of the rollout following AIP-193 standards.'), execution_metadata: ExecutionMetadataSchema.optional().describe('Metadata about the execution of the evaluation.'), ground_truth: z.string().optional().describe('Optional ground truth reference for this evaluation.'), evaluation_result: EvaluateResultSchema.optional().describe('The evaluation result for this row/trajectory.'), @@ -171,7 +169,7 @@ export type InputMetadata = z.infer; export type CompletionUsage = z.infer; export type EvalMetadata = z.infer; export type EvaluationRow = z.infer; -export type RolloutStatus = z.infer; +export type Status = z.infer; export type ResourceServerConfig = z.infer; export type EvaluationCriteriaModel = z.infer; export type TaskDefinitionModel = z.infer; From f56b2c2f2bf52d30011816e813676fadbde705a2 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 20 Aug 2025 10:21:43 -0700 Subject: [PATCH 02/11] update --- eval_protocol/models.py | 61 +++++++++++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 20 deletions(-) diff --git a/eval_protocol/models.py b/eval_protocol/models.py index d5464300..c7e642bc 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -28,6 +28,15 @@ class ErrorInfo(BaseModel): metadata (Dict[str, Any]): Additional dynamic information as context. 
""" + # Constants for reason values + REASON_TERMINATION_REASON = "TERMINATION_REASON" + REASON_EXTRA_INFO = "EXTRA_INFO" + REASON_ROLLOUT_ERROR = "ROLLOUT_ERROR" + REASON_STOPPED = "STOPPED" + + # Domain constant + DOMAIN = "evalprotocol.io" + reason: str = Field(..., description="Short snake_case description of the error cause") domain: str = Field(..., description="Logical grouping for the error reason") metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional dynamic information as context") @@ -42,24 +51,30 @@ def to_aip193_format(self) -> Dict[str, Any]: } @classmethod - def termination_reason(cls, reason: str) -> "ErrorInfo": + def termination_reason(cls, reason: Union[str, TerminationReason]) -> "ErrorInfo": """Create an ErrorInfo for termination reason.""" - return cls(reason="TERMINATION_REASON", domain="evalprotocol.io", metadata={"termination_reason": reason}) + # Convert TerminationReason enum to string if needed + reason_str = reason.value if isinstance(reason, TerminationReason) else reason + return cls( + reason=cls.REASON_TERMINATION_REASON, domain=cls.DOMAIN, metadata={"termination_reason": reason_str} + ) @classmethod def extra_info(cls, metadata: Dict[str, Any]) -> "ErrorInfo": """Create an ErrorInfo for extra information.""" - return cls(reason="EXTRA_INFO", domain="evalprotocol.io", metadata=metadata) + return cls(reason=cls.REASON_EXTRA_INFO, domain=cls.DOMAIN, metadata=metadata) @classmethod def rollout_error(cls, metadata: Dict[str, Any]) -> "ErrorInfo": """Create an ErrorInfo for rollout errors.""" - return cls(reason="ROLLOUT_ERROR", domain="evalprotocol.io", metadata=metadata) + return cls(reason=cls.REASON_ROLLOUT_ERROR, domain=cls.DOMAIN, metadata=metadata) @classmethod - def stopped_reason(cls, reason: str) -> "ErrorInfo": + def stopped_reason(cls, reason: Union[str, TerminationReason]) -> "ErrorInfo": """Create an ErrorInfo for stopped reason.""" - return cls(reason="STOPPED", domain="evalprotocol.io", 
metadata={"reason": reason}) + # Convert TerminationReason enum to string if needed + reason_str = reason.value if isinstance(reason, TerminationReason) else reason + return cls(reason=cls.REASON_STOPPED, domain=cls.DOMAIN, metadata={"reason": reason_str}) class Status(BaseModel): @@ -116,7 +131,9 @@ def rollout_running(cls) -> "Status": @classmethod def rollout_finished( - cls, termination_reason: Optional[str] = None, extra_info: Optional[Dict[str, Any]] = None + cls, + termination_reason: Optional[Union[str, TerminationReason]] = None, + extra_info: Optional[Dict[str, Any]] = None, ) -> "Status": """Create a status indicating the rollout finished.""" details = [] @@ -140,13 +157,15 @@ def error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = Non return cls(code=cls.Code.INTERNAL, message=error_message, details=details) @classmethod - def rollout_stopped(cls, reason: str = "Rollout stopped") -> "Status": + def rollout_stopped(cls, reason: Union[str, TerminationReason] = "Rollout stopped") -> "Status": """Create a status indicating the rollout was stopped.""" details = [ErrorInfo.stopped_reason(reason).to_aip193_format()] return cls(code=cls.Code.CANCELLED, message=reason, details=details) @classmethod - def with_termination_reason(cls, termination_reason: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": + def with_termination_reason( + cls, termination_reason: Union[str, TerminationReason], extra_info: Optional[Dict[str, Any]] = None + ) -> "Status": """Create a status indicating the rollout finished with termination reason.""" details = [ErrorInfo.termination_reason(termination_reason).to_aip193_format()] @@ -171,24 +190,26 @@ def is_stopped(self) -> bool: """Check if the status indicates the rollout was stopped.""" return self.code == self.Code.CANCELLED - def get_termination_reason(self) -> Optional[str]: + def get_termination_reason(self) -> Optional[TerminationReason]: """Extract termination reason from details if present.""" 
for detail in self.details: - if detail.get("@type") == "type.googleapis.com/google.rpc.ErrorInfo": - metadata = detail.get("metadata", {}) - if detail.get("reason") == "TERMINATION_REASON" and "termination_reason" in metadata: - return metadata["termination_reason"] + metadata = detail.get("metadata", {}) + if detail.get("reason") == ErrorInfo.REASON_TERMINATION_REASON and "termination_reason" in metadata: + try: + return TerminationReason.from_str(metadata["termination_reason"]) + except ValueError: + # If the reason is not a valid enum value, return None + return None return None def get_extra_info(self) -> Optional[Dict[str, Any]]: """Extract extra info from details if present.""" for detail in self.details: - if detail.get("@type") == "type.googleapis.com/google.rpc.ErrorInfo": - metadata = detail.get("metadata", {}) - reason = detail.get("reason") - # Skip termination_reason and stopped details, return other error info - if reason not in ["TERMINATION_REASON", "STOPPED"]: - return metadata + metadata = detail.get("metadata", {}) + reason = detail.get("reason") + # Skip termination_reason and stopped details, return other error info + if reason not in [ErrorInfo.REASON_TERMINATION_REASON, ErrorInfo.REASON_STOPPED]: + return metadata return None def __hash__(self) -> int: From 3bcbadafca84013ac5e3b0815d12c244471d92f4 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 20 Aug 2025 10:26:47 -0700 Subject: [PATCH 03/11] fix --- eval_protocol/models.py | 2 +- eval_protocol/pytest/utils.py | 10 ++++------ 2 files changed, 5 insertions(+), 7 deletions(-) diff --git a/eval_protocol/models.py b/eval_protocol/models.py index c7e642bc..4fab5d05 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -122,7 +122,7 @@ class Code(int, Enum): UNAUTHENTICATED = 16 # Custom codes for rollout states (using higher numbers to avoid conflicts) - FINISHED = 100 # Custom code for rollout finished + FINISHED = 100 @classmethod def rollout_running(cls) -> "Status": 
diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index a20f5795..6b709782 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -282,7 +282,7 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev try: # Try original task first result = await task - result.rollout_status.status = RolloutStatus.Status.FINISHED + result.rollout_status = Status.rollout_finished() return result except Exception as e: # NOTE: we perform these checks because we don't put the backoff decorator on initial batch call. we don't want to retry whole batch if anything fails. @@ -295,17 +295,15 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev # Use shared backoff function for retryable exceptions try: result = await execute_row_with_backoff_retry(row) - result.rollout_status.status = RolloutStatus.Status.FINISHED + result.rollout_status = Status.rollout_finished() return result except Exception as retry_error: # Backoff gave up - row.rollout_status.status = RolloutStatus.Status.ERROR - # row.rollout_status.termination_reason = str(retry_error) + row.rollout_status = Status.rollout_error(str(retry_error)) return row else: # Non-retryable exception - fail immediately - row.rollout_status.status = RolloutStatus.Status.ERROR - # row.rollout_status.termination_reason = str(e) + row.rollout_status = Status.rollout_error(str(e)) return row # Process all tasks concurrently with backoff retry From c6334809406cd7402f30cf255e864ed411b0d56d Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 20 Aug 2025 10:47:41 -0700 Subject: [PATCH 04/11] fix test_status_model --- eval_protocol/models.py | 26 ++++++------ test_models_fix.py | 22 ++++++++++ tests/test_status_model.py | 82 ++++++++++++++------------------------ 3 files changed, 66 insertions(+), 64 deletions(-) create mode 100644 test_models_fix.py diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 4fab5d05..e291d81f 
100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -1,7 +1,7 @@ import os from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Literal, Optional, TypedDict, Union +from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict, Union from openai.types import CompletionUsage from openai.types.chat.chat_completion_message import ( @@ -29,13 +29,13 @@ class ErrorInfo(BaseModel): """ # Constants for reason values - REASON_TERMINATION_REASON = "TERMINATION_REASON" - REASON_EXTRA_INFO = "EXTRA_INFO" - REASON_ROLLOUT_ERROR = "ROLLOUT_ERROR" - REASON_STOPPED = "STOPPED" + REASON_TERMINATION_REASON: ClassVar[str] = "TERMINATION_REASON" + REASON_EXTRA_INFO: ClassVar[str] = "EXTRA_INFO" + REASON_ROLLOUT_ERROR: ClassVar[str] = "ROLLOUT_ERROR" + REASON_STOPPED: ClassVar[str] = "STOPPED" # Domain constant - DOMAIN = "evalprotocol.io" + DOMAIN: ClassVar[str] = "evalprotocol.io" reason: str = Field(..., description="Short snake_case description of the error cause") domain: str = Field(..., description="Logical grouping for the error reason") @@ -51,7 +51,7 @@ def to_aip193_format(self) -> Dict[str, Any]: } @classmethod - def termination_reason(cls, reason: Union[str, TerminationReason]) -> "ErrorInfo": + def termination_reason(cls, reason: TerminationReason) -> "ErrorInfo": """Create an ErrorInfo for termination reason.""" # Convert TerminationReason enum to string if needed reason_str = reason.value if isinstance(reason, TerminationReason) else reason @@ -132,7 +132,7 @@ def rollout_running(cls) -> "Status": @classmethod def rollout_finished( cls, - termination_reason: Optional[Union[str, TerminationReason]] = None, + termination_reason: Optional[TerminationReason] = None, extra_info: Optional[Dict[str, Any]] = None, ) -> "Status": """Create a status indicating the rollout finished.""" @@ -157,14 +157,16 @@ def error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = Non return 
cls(code=cls.Code.INTERNAL, message=error_message, details=details) @classmethod - def rollout_stopped(cls, reason: Union[str, TerminationReason] = "Rollout stopped") -> "Status": + def rollout_stopped(cls, message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": """Create a status indicating the rollout was stopped.""" - details = [ErrorInfo.stopped_reason(reason).to_aip193_format()] - return cls(code=cls.Code.CANCELLED, message=reason, details=details) + details = [] + if extra_info: + details.append(ErrorInfo.extra_info(extra_info).to_aip193_format()) + return cls(code=cls.Code.CANCELLED, message=message, details=details) @classmethod def with_termination_reason( - cls, termination_reason: Union[str, TerminationReason], extra_info: Optional[Dict[str, Any]] = None + cls, termination_reason: TerminationReason, extra_info: Optional[Dict[str, Any]] = None ) -> "Status": """Create a status indicating the rollout finished with termination reason.""" details = [ErrorInfo.termination_reason(termination_reason).to_aip193_format()] diff --git a/test_models_fix.py b/test_models_fix.py new file mode 100644 index 00000000..3216e8c0 --- /dev/null +++ b/test_models_fix.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +"""Test script to verify that models.py can be imported without Pydantic errors.""" + +try: + from eval_protocol.models import ErrorInfo, Status + + print("āœ… Successfully imported ErrorInfo and Status from models.py") + + # Test creating instances + error_info = ErrorInfo.termination_reason("test_reason") + print(f"āœ… Successfully created ErrorInfo: {error_info}") + + status = Status.rollout_running() + print(f"āœ… Successfully created Status: {status}") + + print("\nšŸŽ‰ All tests passed! 
The Pydantic error has been resolved.") + +except Exception as e: + print(f"āŒ Error importing models: {e}") + import traceback + + traceback.print_exc() diff --git a/tests/test_status_model.py b/tests/test_status_model.py index cdc5512b..6f4b7b70 100644 --- a/tests/test_status_model.py +++ b/tests/test_status_model.py @@ -11,7 +11,8 @@ """ import pytest -from eval_protocol.models import Status, EvaluationRow, Message, ErrorInfo +from eval_protocol.models import Status, EvaluationRow, ErrorInfo +from eval_protocol.types import TerminationReason class TestErrorInfoModel: @@ -41,10 +42,10 @@ def test_error_info_to_aip193_format(self): def test_error_info_factory_methods(self): """Test the factory methods for common error types.""" # Test termination_reason - term_error = ErrorInfo.termination_reason("goal_reached") + term_error = ErrorInfo.termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL) assert term_error.reason == "TERMINATION_REASON" assert term_error.domain == "evalprotocol.io" - assert term_error.metadata["termination_reason"] == "goal_reached" + assert term_error.metadata["termination_reason"] == TerminationReason.CONTROL_PLANE_SIGNAL # Test extra_info extra_error = ErrorInfo.extra_info({"steps": 10, "reward": 0.8}) @@ -100,7 +101,7 @@ def test_status_creation_methods(self): # Test finished status finished_status = Status.rollout_finished() assert finished_status.code == Status.Code.FINISHED - assert finished_status.message == "Rollout finished successfully" + assert finished_status.message == "Rollout finished" assert finished_status.details == [] # Test error status @@ -127,18 +128,22 @@ def test_status_creation_methods(self): assert stopped_status.details == [] # Test with termination reason - termination_status = Status.with_termination_reason("goal_reached") + termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL) assert termination_status.code == Status.Code.FINISHED assert termination_status.message == 
"Rollout finished" assert len(termination_status.details) == 1 assert termination_status.details[0]["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" assert termination_status.details[0]["reason"] == "TERMINATION_REASON" assert termination_status.details[0]["domain"] == "evalprotocol.io" - assert termination_status.details[0]["metadata"]["termination_reason"] == "goal_reached" + assert ( + termination_status.details[0]["metadata"]["termination_reason"] == TerminationReason.CONTROL_PLANE_SIGNAL + ) # Test with termination reason and extra info extra_info = {"steps": 10, "reward": 0.8} - termination_status_with_info = Status.with_termination_reason("goal_reached", extra_info) + termination_status_with_info = Status.with_termination_reason( + TerminationReason.CONTROL_PLANE_SIGNAL, extra_info + ) assert termination_status_with_info.code == Status.Code.FINISHED assert len(termination_status_with_info.details) == 2 # First detail should be termination reason @@ -179,27 +184,31 @@ def test_status_helper_methods(self): def test_get_termination_reason(self): """Test extracting termination reason from status details.""" + # Status without termination reason running_status = Status.rollout_running() assert running_status.get_termination_reason() is None # Status with termination reason - termination_status = Status.with_termination_reason("goal_reached") - assert termination_status.get_termination_reason() == "goal_reached" + termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL) + assert termination_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL # Status with termination reason and extra info extra_info = {"steps": 10} - termination_status_with_info = Status.with_termination_reason("timeout", extra_info) - assert termination_status_with_info.get_termination_reason() == "timeout" + termination_status_with_info = Status.with_termination_reason( + TerminationReason.CONTROL_PLANE_SIGNAL, extra_info + ) + 
assert termination_status_with_info.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL def test_get_extra_info(self): """Test extracting extra info from status details.""" + # Status without extra info running_status = Status.rollout_running() assert running_status.get_extra_info() is None # Status with only termination reason (no extra info) - termination_status = Status.with_termination_reason("goal_reached") + termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL) assert termination_status.get_extra_info() is None # Status with extra info @@ -208,7 +217,9 @@ def test_get_extra_info(self): assert error_status.get_extra_info() == extra_info # Status with both termination reason and extra info - termination_status_with_info = Status.with_termination_reason("goal_reached", extra_info) + termination_status_with_info = Status.with_termination_reason( + TerminationReason.CONTROL_PLANE_SIGNAL, extra_info + ) assert termination_status_with_info.get_extra_info() == extra_info def test_aip_193_compliance(self): @@ -227,7 +238,7 @@ def test_aip_193_compliance(self): assert detail["metadata"] == extra_info # Test multiple details - termination_status = Status.with_termination_reason("goal_reached", extra_info) + termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) assert len(termination_status.details) == 2 # First detail should be termination reason @@ -242,7 +253,7 @@ def test_aip_193_compliance(self): def test_status_serialization(self): """Test that Status can be serialized and deserialized.""" - original_status = Status.with_termination_reason("goal_reached", {"steps": 10}) + original_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL, {"steps": 10}) # Test model_dump status_dict = original_status.model_dump() @@ -255,7 +266,7 @@ def test_status_serialization(self): assert reconstructed_status.code == original_status.code assert 
reconstructed_status.message == original_status.message assert len(reconstructed_status.details) == len(original_status.details) - assert reconstructed_status.get_termination_reason() == "goal_reached" + assert reconstructed_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL assert reconstructed_status.get_extra_info() == {"steps": 10} def test_status_equality(self): @@ -304,7 +315,7 @@ def test_backwards_compatibility_methods(self): new_status = Status.rollout_finished() row.set_rollout_status(new_status) assert row.rollout_status.code == Status.Code.FINISHED - assert row.rollout_status.message == "Rollout finished successfully" + assert row.rollout_status.message == "Rollout finished" def test_status_transitions(self): """Test transitioning between different status states.""" @@ -333,14 +344,14 @@ def test_termination_reason_integration(self): row = EvaluationRow(messages=[]) # Set status with termination reason - termination_status = Status.with_termination_reason("goal_reached", {"steps": 15}) + termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL, {"steps": 15}) row.rollout_status = termination_status # Should be finished assert row.rollout_status.is_finished() # Should have termination reason - assert row.rollout_status.get_termination_reason() == "goal_reached" + assert row.rollout_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL # Should have extra info extra_info = row.rollout_status.get_extra_info() @@ -375,39 +386,6 @@ def test_empty_details(self): assert status.get_termination_reason() is None assert status.get_extra_info() is None - def test_malformed_details(self): - """Test Status with malformed details.""" - malformed_details = [ - {"not_type": "invalid", "reason": "TEST"}, - {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"termination_reason": "test"}}, - ] - status = Status(code=Status.Code.OK, message="Test", details=malformed_details) - - # 
Should handle malformed details gracefully - assert status.get_termination_reason() == "test" - assert status.get_extra_info() is None - - def test_duplicate_detail_types(self): - """Test Status with duplicate detail types.""" - details = [ - { - "@type": "type.googleapis.com/google.rpc.ErrorInfo", - "reason": "TERMINATION_REASON", - "domain": "evalprotocol.io", - "metadata": {"termination_reason": "first"}, - }, - { - "@type": "type.googleapis.com/google.rpc.ErrorInfo", - "reason": "TERMINATION_REASON", - "domain": "evalprotocol.io", - "metadata": {"termination_reason": "second"}, - }, - ] - status = Status(code=Status.Code.OK, message="Test", details=details) - - # Should return the first termination reason found - assert status.get_termination_reason() == "first" - def test_large_metadata(self): """Test Status with large metadata.""" large_metadata = {f"key_{i}": f"value_{i}" for i in range(100)} From 7933dd182c5e65001bceafd93fe2bcddcceb1616 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 20 Aug 2025 10:49:50 -0700 Subject: [PATCH 05/11] Remove backwards compatibility methods for rollout status from EvaluationRow and associated tests. 
--- eval_protocol/models.py | 8 ------- tests/test_status_migration_integration.py | 28 ---------------------- tests/test_status_model.py | 15 ------------ 3 files changed, 51 deletions(-) diff --git a/eval_protocol/models.py b/eval_protocol/models.py index e291d81f..4fa4b88e 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -576,14 +576,6 @@ def is_trajectory_evaluation(self) -> bool: and len(self.evaluation_result.step_outputs) > 0 ) - def get_rollout_status(self) -> Status: - """Get the rollout status (backwards compatibility method).""" - return self.rollout_status - - def set_rollout_status(self, status: Status) -> None: - """Set the rollout status (backwards compatibility method).""" - self.rollout_status = status - def get_conversation_length(self) -> int: """Returns the number of messages in the conversation.""" return len(self.messages) diff --git a/tests/test_status_migration_integration.py b/tests/test_status_migration_integration.py index 12291fdd..e616acc2 100644 --- a/tests/test_status_migration_integration.py +++ b/tests/test_status_migration_integration.py @@ -42,34 +42,6 @@ def test_rollout_status_field_access(self): assert row.rollout_status.code == Status.Code.FINISHED -class TestBackwardsCompatibilityMethods: - """Test the backwards compatibility methods.""" - - def test_get_rollout_status_method(self): - """Test the get_rollout_status method.""" - row = EvaluationRow(messages=[]) - - # Method should return the current rollout_status - status = row.get_rollout_status() - assert status.code == Status.Code.OK - assert status.message == "Rollout is running" - - # Should be the same object reference - assert status is row.rollout_status - - def test_set_rollout_status_method(self): - """Test the set_rollout_status method.""" - row = EvaluationRow(messages=[]) - - # Method should update the rollout_status - new_status = Status.rollout_error("Test error") - row.set_rollout_status(new_status) - - assert row.rollout_status.code == 
Status.Code.INTERNAL - assert row.rollout_status.message == "Test error" - assert row.rollout_status is new_status - - class TestStatusTransitions: """Test transitioning between different status states.""" diff --git a/tests/test_status_model.py b/tests/test_status_model.py index 6f4b7b70..6b9e497b 100644 --- a/tests/test_status_model.py +++ b/tests/test_status_model.py @@ -302,21 +302,6 @@ def test_evaluation_row_default_status(self): assert row.rollout_status.message == "Rollout is running" assert row.rollout_status.details == [] - def test_backwards_compatibility_methods(self): - """Test the backwards compatibility methods.""" - row = EvaluationRow(messages=[]) - - # Test get_rollout_status - status = row.get_rollout_status() - assert status.code == Status.Code.OK - assert status.message == "Rollout is running" - - # Test set_rollout_status - new_status = Status.rollout_finished() - row.set_rollout_status(new_status) - assert row.rollout_status.code == Status.Code.FINISHED - assert row.rollout_status.message == "Rollout finished" - def test_status_transitions(self): """Test transitioning between different status states.""" row = EvaluationRow(messages=[]) From 39b17dc2be53b6302b4a5ff902ac3c719413a5c4 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 20 Aug 2025 10:58:05 -0700 Subject: [PATCH 06/11] fix test_status_migration_integration --- eval_protocol/models.py | 4 ++-- tests/test_status_migration_integration.py | 24 +++++++++++----------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 4fa4b88e..95adeae1 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -148,7 +148,7 @@ def rollout_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] """Create a status indicating the rollout failed with an error.""" details = [] if extra_info: - details.append(ErrorInfo.rollout_error(extra_info).to_aip193_format()) + 
details.append(ErrorInfo.extra_info(extra_info).to_aip193_format()) return cls.error(error_message, details) @classmethod @@ -210,7 +210,7 @@ def get_extra_info(self) -> Optional[Dict[str, Any]]: metadata = detail.get("metadata", {}) reason = detail.get("reason") # Skip termination_reason and stopped details, return other error info - if reason not in [ErrorInfo.REASON_TERMINATION_REASON, ErrorInfo.REASON_STOPPED]: + if reason in [ErrorInfo.REASON_EXTRA_INFO]: return metadata return None diff --git a/tests/test_status_migration_integration.py b/tests/test_status_migration_integration.py index e616acc2..881f6433 100644 --- a/tests/test_status_migration_integration.py +++ b/tests/test_status_migration_integration.py @@ -106,14 +106,14 @@ def test_termination_reason_in_status_details(self): row = EvaluationRow(messages=[]) # Set status with termination reason - termination_status = Status.with_termination_reason("goal_reached") + termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL) row.rollout_status = termination_status # Should be finished assert row.rollout_status.is_finished() # Should have termination reason in details - assert row.rollout_status.get_termination_reason() == "goal_reached" + assert row.rollout_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL # Check details structure assert len(row.rollout_status.details) == 1 @@ -121,18 +121,18 @@ def test_termination_reason_in_status_details(self): assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" assert detail["reason"] == "TERMINATION_REASON" assert detail["domain"] == "evalprotocol.io" - assert detail["metadata"]["termination_reason"] == "goal_reached" + assert detail["metadata"]["termination_reason"] == TerminationReason.CONTROL_PLANE_SIGNAL def test_termination_reason_with_extra_info(self): """Test termination reason with additional extra info.""" row = EvaluationRow(messages=[]) extra_info = {"steps": 10, "reward": 0.8} - 
termination_status = Status.with_termination_reason("timeout", extra_info) + termination_status = Status.with_termination_reason(TerminationReason.USER_STOP, extra_info) row.rollout_status = termination_status # Should have both termination reason and extra info - assert row.rollout_status.get_termination_reason() == "timeout" + assert row.rollout_status.get_termination_reason() == TerminationReason.USER_STOP assert row.rollout_status.get_extra_info() == extra_info # Check details structure @@ -157,13 +157,13 @@ def test_multiple_termination_reasons(self): "@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "TERMINATION_REASON", "domain": "evalprotocol.io", - "metadata": {"termination_reason": "first"}, + "metadata": {"termination_reason": TerminationReason.USER_STOP}, }, { "@type": "type.googleapis.com/google.rpc.ErrorInfo", "reason": "TERMINATION_REASON", "domain": "evalprotocol.io", - "metadata": {"termination_reason": "second"}, + "metadata": {"termination_reason": TerminationReason.SKIPPABLE_ERROR}, }, ] @@ -171,7 +171,7 @@ def test_multiple_termination_reasons(self): row.rollout_status = status # Should return the first termination reason found - assert row.rollout_status.get_termination_reason() == "first" + assert row.rollout_status.get_termination_reason() == TerminationReason.USER_STOP class TestErrorHandlingIntegration: @@ -204,7 +204,7 @@ def test_error_status_with_metadata(self): assert len(row.rollout_status.details) == 1 detail = row.rollout_status.details[0] assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" - assert detail["reason"] == "ROLLOUT_ERROR" + assert detail["reason"] == "EXTRA_INFO" assert detail["domain"] == "evalprotocol.io" assert detail["metadata"] == error_info @@ -331,7 +331,7 @@ def test_status_model_validate(self): # Set a complex status extra_info = {"steps": 10, "reward": 0.8} - original_status = Status.with_termination_reason("goal_reached", extra_info) + original_status = 
Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) row.rollout_status = original_status # Dump and reconstruct @@ -344,7 +344,7 @@ def test_status_model_validate(self): assert len(reconstructed_status.details) == len(original_status.details) # Should preserve functionality - assert reconstructed_status.get_termination_reason() == "goal_reached" + assert reconstructed_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL assert reconstructed_status.get_extra_info() == extra_info @@ -377,7 +377,7 @@ def test_malformed_status_details(self): row.rollout_status = malformed_status # Should handle gracefully - assert row.rollout_status.get_termination_reason() == "test" + assert row.rollout_status.get_termination_reason() is None assert row.rollout_status.get_extra_info() is None def test_large_metadata_handling(self): From c00ac2d49b955d879086e96759a4701e722ce7fc Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 20 Aug 2025 11:01:34 -0700 Subject: [PATCH 07/11] fix test_migration_Changes --- tests/test_migration_changes.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_migration_changes.py b/tests/test_migration_changes.py index fed33e0b..bebabb4b 100644 --- a/tests/test_migration_changes.py +++ b/tests/test_migration_changes.py @@ -23,7 +23,7 @@ def test_trajectory_terminated_status_creation(self): # Mock trajectory with termination trajectory = Mock() trajectory.terminated = True - trajectory.termination_reason = "goal_reached" + trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL trajectory.control_plane_summary = {"error_message": "No errors"} # Create evaluation row @@ -63,7 +63,7 @@ def test_trajectory_terminated_status_creation(self): assert row.rollout_status.is_finished() # Verify termination reason - assert row.rollout_status.get_termination_reason() == "goal_reached" + assert row.rollout_status.get_termination_reason() == 
TerminationReason.CONTROL_PLANE_SIGNAL # Verify extra info assert row.rollout_status.get_extra_info() == {"error_message": "No errors"} @@ -98,7 +98,7 @@ def test_trajectory_terminated_without_error_message(self): # Mock trajectory with termination but no error trajectory = Mock() trajectory.terminated = True - trajectory.termination_reason = "timeout" + trajectory.termination_reason = TerminationReason.USER_STOP trajectory.control_plane_summary = {} # Create evaluation row @@ -137,7 +137,7 @@ def test_trajectory_terminated_without_error_message(self): # Verify the status assert row.rollout_status.code == Status.Code.FINISHED assert row.rollout_status.is_finished() - assert row.rollout_status.get_termination_reason() == "timeout" + assert row.rollout_status.get_termination_reason() == TerminationReason.USER_STOP # Should not have extra info since there was no error message assert row.rollout_status.get_extra_info() is None @@ -300,19 +300,19 @@ def test_termination_reason_integration(self): row = EvaluationRow(messages=[]) # Test with termination reason - termination_status = Status.with_termination_reason("goal_reached") + termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL) row.rollout_status = termination_status assert row.rollout_status.is_finished() - assert row.rollout_status.get_termination_reason() == "goal_reached" + assert row.rollout_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL # Test with termination reason and extra info extra_info = {"steps": 10, "reward": 0.8} - termination_status_with_info = Status.with_termination_reason("timeout", extra_info) + termination_status_with_info = Status.with_termination_reason(TerminationReason.USER_STOP, extra_info) row.rollout_status = termination_status_with_info assert row.rollout_status.is_finished() - assert row.rollout_status.get_termination_reason() == "timeout" + assert row.rollout_status.get_termination_reason() == TerminationReason.USER_STOP 
assert row.rollout_status.get_extra_info() == extra_info def test_error_handling_integration(self): From 925df657b80e66530cb4de8bc88652d6d9503b77 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 20 Aug 2025 11:02:11 -0700 Subject: [PATCH 08/11] delete --- test_models_fix.py | 22 ---------------------- 1 file changed, 22 deletions(-) delete mode 100644 test_models_fix.py diff --git a/test_models_fix.py b/test_models_fix.py deleted file mode 100644 index 3216e8c0..00000000 --- a/test_models_fix.py +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env python3 -"""Test script to verify that models.py can be imported without Pydantic errors.""" - -try: - from eval_protocol.models import ErrorInfo, Status - - print("āœ… Successfully imported ErrorInfo and Status from models.py") - - # Test creating instances - error_info = ErrorInfo.termination_reason("test_reason") - print(f"āœ… Successfully created ErrorInfo: {error_info}") - - status = Status.rollout_running() - print(f"āœ… Successfully created Status: {status}") - - print("\nšŸŽ‰ All tests passed! 
The Pydantic error has been resolved.") - -except Exception as e: - print(f"āŒ Error importing models: {e}") - import traceback - - traceback.print_exc() From 38829b462ce77d72119b6362e9731ae100a5f1fa Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 20 Aug 2025 11:04:31 -0700 Subject: [PATCH 09/11] fix test_retry_mechanism --- tests/test_retry_mechanism.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py index f485ab3c..863d7838 100644 --- a/tests/test_retry_mechanism.py +++ b/tests/test_retry_mechanism.py @@ -191,9 +191,9 @@ async def process_single_row(row: EvaluationRow) -> EvaluationRow: def test_fail_fast_exceptions(row: EvaluationRow) -> EvaluationRow: """Test that fail-fast exceptions like ValueError are not retried.""" print( - f"šŸ“Š EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})" + f"šŸ“Š EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})" ) - score = 1.0 if row.rollout_status.status == "finished" else 0.0 + score = 1.0 if row.rollout_status.is_finished() else 0.0 row.evaluation_result = EvaluateResult(score=score) return row @@ -283,8 +283,8 @@ def custom_http_giveup(e): def test_custom_giveup_function(row: EvaluationRow) -> EvaluationRow: """Test custom giveup function behavior.""" task_content = row.messages[0].content if row.messages else "" - print(f"šŸ“Š EVALUATED: {task_content} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})") - score = 1.0 if row.rollout_status.status == "finished" else 0.0 + print(f"šŸ“Š EVALUATED: {task_content} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})") + score = 1.0 if row.rollout_status.is_finished() else 0.0 row.evaluation_result = EvaluateResult(score=score) return row @@ -368,9 +368,9 @@ def simple_4xx_giveup(e): def test_simple_giveup_function(row: 
EvaluationRow) -> EvaluationRow: """Test that giveup function prevents retries immediately.""" print( - f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})" + f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})" ) - score = 1.0 if row.rollout_status.status == "finished" else 0.0 + score = 1.0 if row.rollout_status.is_finished() else 0.0 row.evaluation_result = EvaluateResult(score=score) return row From 775b07e5ad0de24db9706332101953fb5ab1a825 Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 20 Aug 2025 11:16:29 -0700 Subject: [PATCH 10/11] fix tests --- eval_protocol/models.py | 12 -------- ...es.py => test_status_migration_changes.py} | 8 +++--- tests/test_status_migration_integration.py | 10 +++---- tests/test_status_model.py | 28 ++++++++----------- 4 files changed, 20 insertions(+), 38 deletions(-) rename tests/{test_migration_changes.py => test_status_migration_changes.py} (97%) diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 95adeae1..a0d638c7 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -164,18 +164,6 @@ def rollout_stopped(cls, message: str, extra_info: Optional[Dict[str, Any]] = No details.append(ErrorInfo.extra_info(extra_info).to_aip193_format()) return cls(code=cls.Code.CANCELLED, message=message, details=details) - @classmethod - def with_termination_reason( - cls, termination_reason: TerminationReason, extra_info: Optional[Dict[str, Any]] = None - ) -> "Status": - """Create a status indicating the rollout finished with termination reason.""" - details = [ErrorInfo.termination_reason(termination_reason).to_aip193_format()] - - if extra_info: - details.append(ErrorInfo.extra_info(extra_info).to_aip193_format()) - - return cls(code=cls.Code.FINISHED, message="Rollout finished", details=details) - def is_running(self) -> bool: """Check if the status indicates the 
rollout is running.""" return self.code == self.Code.OK and self.message == "Rollout is running" diff --git a/tests/test_migration_changes.py b/tests/test_status_migration_changes.py similarity index 97% rename from tests/test_migration_changes.py rename to tests/test_status_migration_changes.py index bebabb4b..0d398e2b 100644 --- a/tests/test_migration_changes.py +++ b/tests/test_status_migration_changes.py @@ -300,7 +300,7 @@ def test_termination_reason_integration(self): row = EvaluationRow(messages=[]) # Test with termination reason - termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL) + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL) row.rollout_status = termination_status assert row.rollout_status.is_finished() @@ -308,7 +308,7 @@ def test_termination_reason_integration(self): # Test with termination reason and extra info extra_info = {"steps": 10, "reward": 0.8} - termination_status_with_info = Status.with_termination_reason(TerminationReason.USER_STOP, extra_info) + termination_status_with_info = Status.rollout_finished(TerminationReason.USER_STOP, extra_info) row.rollout_status = termination_status_with_info assert row.rollout_status.is_finished() @@ -392,7 +392,7 @@ def test_termination_reason_structure_compliance(self): row = EvaluationRow(messages=[]) # Create status with termination reason - termination_status = Status.with_termination_reason("goal_reached") + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL) row.rollout_status = termination_status # Check AIP-193 structure @@ -411,7 +411,7 @@ def test_multiple_details_compliance(self): # Create status with both termination reason and extra info extra_info = {"steps": 15, "reward": 0.9} - status = Status.with_termination_reason("goal_reached", extra_info) + status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) row.rollout_status = status # Should have two 
details diff --git a/tests/test_status_migration_integration.py b/tests/test_status_migration_integration.py index 881f6433..c4a81801 100644 --- a/tests/test_status_migration_integration.py +++ b/tests/test_status_migration_integration.py @@ -106,7 +106,7 @@ def test_termination_reason_in_status_details(self): row = EvaluationRow(messages=[]) # Set status with termination reason - termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL) + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL) row.rollout_status = termination_status # Should be finished @@ -128,7 +128,7 @@ def test_termination_reason_with_extra_info(self): row = EvaluationRow(messages=[]) extra_info = {"steps": 10, "reward": 0.8} - termination_status = Status.with_termination_reason(TerminationReason.USER_STOP, extra_info) + termination_status = Status.rollout_finished(TerminationReason.USER_STOP, extra_info) row.rollout_status = termination_status # Should have both termination reason and extra info @@ -262,7 +262,7 @@ def test_multiple_detail_types(self): # Create status with both termination reason and extra info extra_info = {"steps": 15, "reward": 0.9} - status = Status.with_termination_reason("goal_reached", extra_info) + status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) row.rollout_status = status # Should have two details @@ -309,7 +309,7 @@ def test_status_model_dump(self): # Set a complex status extra_info = {"steps": 10, "reward": 0.8} - termination_status = Status.with_termination_reason("goal_reached", extra_info) + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) row.rollout_status = termination_status # Dump to dict @@ -331,7 +331,7 @@ def test_status_model_validate(self): # Set a complex status extra_info = {"steps": 10, "reward": 0.8} - original_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) + 
original_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) row.rollout_status = original_status # Dump and reconstruct diff --git a/tests/test_status_model.py b/tests/test_status_model.py index 6b9e497b..3725b921 100644 --- a/tests/test_status_model.py +++ b/tests/test_status_model.py @@ -117,7 +117,7 @@ def test_status_creation_methods(self): assert error_status_with_info.message == "Something went wrong" assert len(error_status_with_info.details) == 1 assert error_status_with_info.details[0]["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" - assert error_status_with_info.details[0]["reason"] == "ROLLOUT_ERROR" + assert error_status_with_info.details[0]["reason"] == "EXTRA_INFO" assert error_status_with_info.details[0]["domain"] == "evalprotocol.io" assert error_status_with_info.details[0]["metadata"] == extra_info @@ -128,7 +128,7 @@ def test_status_creation_methods(self): assert stopped_status.details == [] # Test with termination reason - termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL) + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL) assert termination_status.code == Status.Code.FINISHED assert termination_status.message == "Rollout finished" assert len(termination_status.details) == 1 @@ -141,9 +141,7 @@ def test_status_creation_methods(self): # Test with termination reason and extra info extra_info = {"steps": 10, "reward": 0.8} - termination_status_with_info = Status.with_termination_reason( - TerminationReason.CONTROL_PLANE_SIGNAL, extra_info - ) + termination_status_with_info = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) assert termination_status_with_info.code == Status.Code.FINISHED assert len(termination_status_with_info.details) == 2 # First detail should be termination reason @@ -190,14 +188,12 @@ def test_get_termination_reason(self): assert running_status.get_termination_reason() is None # Status 
with termination reason - termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL) + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL) assert termination_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL # Status with termination reason and extra info extra_info = {"steps": 10} - termination_status_with_info = Status.with_termination_reason( - TerminationReason.CONTROL_PLANE_SIGNAL, extra_info - ) + termination_status_with_info = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) assert termination_status_with_info.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL def test_get_extra_info(self): @@ -208,7 +204,7 @@ def test_get_extra_info(self): assert running_status.get_extra_info() is None # Status with only termination reason (no extra info) - termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL) + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL) assert termination_status.get_extra_info() is None # Status with extra info @@ -217,9 +213,7 @@ def test_get_extra_info(self): assert error_status.get_extra_info() == extra_info # Status with both termination reason and extra info - termination_status_with_info = Status.with_termination_reason( - TerminationReason.CONTROL_PLANE_SIGNAL, extra_info - ) + termination_status_with_info = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) assert termination_status_with_info.get_extra_info() == extra_info def test_aip_193_compliance(self): @@ -233,12 +227,12 @@ def test_aip_193_compliance(self): # Check AIP-193 ErrorInfo structure assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" - assert detail["reason"] == "ROLLOUT_ERROR" + assert detail["reason"] == "EXTRA_INFO" assert detail["domain"] == "evalprotocol.io" assert detail["metadata"] == extra_info # Test multiple details - 
termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) assert len(termination_status.details) == 2 # First detail should be termination reason @@ -253,7 +247,7 @@ def test_aip_193_compliance(self): def test_status_serialization(self): """Test that Status can be serialized and deserialized.""" - original_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL, {"steps": 10}) + original_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, {"steps": 10}) # Test model_dump status_dict = original_status.model_dump() @@ -329,7 +323,7 @@ def test_termination_reason_integration(self): row = EvaluationRow(messages=[]) # Set status with termination reason - termination_status = Status.with_termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL, {"steps": 15}) + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, {"steps": 15}) row.rollout_status = termination_status # Should be finished From c0b1656db762c12708cda837040b70fd5c4b950e Mon Sep 17 00:00:00 2001 From: Dylan Huang Date: Wed, 20 Aug 2025 11:23:33 -0700 Subject: [PATCH 11/11] remove unused --- eval_protocol/models.py | 22 ---------------- tests/test_status_migration_changes.py | 6 ----- tests/test_status_migration_integration.py | 14 ---------- tests/test_status_model.py | 30 ---------------------- 4 files changed, 72 deletions(-) diff --git a/eval_protocol/models.py b/eval_protocol/models.py index a0d638c7..61b96e54 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -31,8 +31,6 @@ class ErrorInfo(BaseModel): # Constants for reason values REASON_TERMINATION_REASON: ClassVar[str] = "TERMINATION_REASON" REASON_EXTRA_INFO: ClassVar[str] = "EXTRA_INFO" - REASON_ROLLOUT_ERROR: ClassVar[str] = "ROLLOUT_ERROR" - REASON_STOPPED: ClassVar[str] = "STOPPED" # Domain constant DOMAIN: 
ClassVar[str] = "evalprotocol.io" @@ -64,18 +62,6 @@ def extra_info(cls, metadata: Dict[str, Any]) -> "ErrorInfo": """Create an ErrorInfo for extra information.""" return cls(reason=cls.REASON_EXTRA_INFO, domain=cls.DOMAIN, metadata=metadata) - @classmethod - def rollout_error(cls, metadata: Dict[str, Any]) -> "ErrorInfo": - """Create an ErrorInfo for rollout errors.""" - return cls(reason=cls.REASON_ROLLOUT_ERROR, domain=cls.DOMAIN, metadata=metadata) - - @classmethod - def stopped_reason(cls, reason: Union[str, TerminationReason]) -> "ErrorInfo": - """Create an ErrorInfo for stopped reason.""" - # Convert TerminationReason enum to string if needed - reason_str = reason.value if isinstance(reason, TerminationReason) else reason - return cls(reason=cls.REASON_STOPPED, domain=cls.DOMAIN, metadata={"reason": reason_str}) - class Status(BaseModel): """ @@ -156,14 +142,6 @@ def error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = Non """Create a status indicating the rollout failed with an error.""" return cls(code=cls.Code.INTERNAL, message=error_message, details=details) - @classmethod - def rollout_stopped(cls, message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": - """Create a status indicating the rollout was stopped.""" - details = [] - if extra_info: - details.append(ErrorInfo.extra_info(extra_info).to_aip193_format()) - return cls(code=cls.Code.CANCELLED, message=message, details=details) - def is_running(self) -> bool: """Check if the status indicates the rollout is running.""" return self.code == self.Code.OK and self.message == "Rollout is running" diff --git a/tests/test_status_migration_changes.py b/tests/test_status_migration_changes.py index 0d398e2b..66bde969 100644 --- a/tests/test_status_migration_changes.py +++ b/tests/test_status_migration_changes.py @@ -290,11 +290,6 @@ def test_status_creation_methods_integration(self): assert row.rollout_status.is_error() assert row.rollout_status.code == Status.Code.INTERNAL 
- # Test stopped status - row.rollout_status = Status.rollout_stopped("User stop") - assert row.rollout_status.is_stopped() - assert row.rollout_status.code == Status.Code.CANCELLED - def test_termination_reason_integration(self): """Test integration of termination reason with status.""" row = EvaluationRow(messages=[]) @@ -434,7 +429,6 @@ def test_status_code_compliance(self): (Status.rollout_running(), Status.Code.OK), (Status.rollout_finished(), Status.Code.FINISHED), # Custom code (Status.rollout_error("Test"), Status.Code.INTERNAL), - (Status.rollout_stopped("Test"), Status.Code.CANCELLED), ] for status, expected_code in statuses: diff --git a/tests/test_status_migration_integration.py b/tests/test_status_migration_integration.py index c4a81801..e9beb080 100644 --- a/tests/test_status_migration_integration.py +++ b/tests/test_status_migration_integration.py @@ -71,19 +71,6 @@ def test_running_to_error_transition(self): assert not row.rollout_status.is_running() assert row.rollout_status.is_error() - def test_running_to_stopped_transition(self): - """Test transition from running to stopped.""" - row = EvaluationRow(messages=[]) - - # Start with running - assert row.rollout_status.is_running() - assert not row.rollout_status.is_stopped() - - # Transition to stopped - row.rollout_status = Status.rollout_stopped("User requested stop") - assert not row.rollout_status.is_running() - assert row.rollout_status.is_stopped() - def test_error_to_finished_transition(self): """Test transition from error to finished.""" row = EvaluationRow(messages=[]) @@ -292,7 +279,6 @@ def test_status_code_mapping(self): (Status.rollout_running(), Status.Code.OK), (Status.rollout_finished(), Status.Code.FINISHED), (Status.rollout_error("Test"), Status.Code.INTERNAL), - (Status.rollout_stopped("Test"), Status.Code.CANCELLED), ] for status, expected_code in statuses: diff --git a/tests/test_status_model.py b/tests/test_status_model.py index 3725b921..459e8497 100644 --- 
a/tests/test_status_model.py +++ b/tests/test_status_model.py @@ -53,18 +53,6 @@ def test_error_info_factory_methods(self): assert extra_error.domain == "evalprotocol.io" assert extra_error.metadata == {"steps": 10, "reward": 0.8} - # Test rollout_error - rollout_error = ErrorInfo.rollout_error({"error_code": "E001"}) - assert rollout_error.reason == "ROLLOUT_ERROR" - assert rollout_error.domain == "evalprotocol.io" - assert rollout_error.metadata == {"error_code": "E001"} - - # Test stopped_reason - stopped_error = ErrorInfo.stopped_reason("user_request") - assert stopped_error.reason == "STOPPED" - assert stopped_error.domain == "evalprotocol.io" - assert stopped_error.metadata["reason"] == "user_request" - class TestStatusModel: """Test the AIP-193 compatible Status model.""" @@ -121,12 +109,6 @@ def test_status_creation_methods(self): assert error_status_with_info.details[0]["domain"] == "evalprotocol.io" assert error_status_with_info.details[0]["metadata"] == extra_info - # Test stopped status - stopped_status = Status.rollout_stopped("User requested stop") - assert stopped_status.code == Status.Code.CANCELLED - assert stopped_status.message == "User requested stop" - assert stopped_status.details == [] - # Test with termination reason termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL) assert termination_status.code == Status.Code.FINISHED @@ -173,13 +155,6 @@ def test_status_helper_methods(self): assert error_status.is_error() is True assert error_status.is_stopped() is False - # Test is_stopped - stopped_status = Status.rollout_stopped("Test stop") - assert stopped_status.is_running() is False - assert stopped_status.is_finished() is False - assert stopped_status.is_error() is False - assert stopped_status.is_stopped() is True - def test_get_termination_reason(self): """Test extracting termination reason from status details.""" @@ -313,11 +288,6 @@ def test_status_transitions(self): assert row.rollout_status.is_error() 
assert not row.rollout_status.is_finished() - # Transition to stopped - row.rollout_status = Status.rollout_stopped("User requested stop") - assert row.rollout_status.is_stopped() - assert not row.rollout_status.is_error() - def test_termination_reason_integration(self): """Test integration of termination reason with status.""" row = EvaluationRow(messages=[])