diff --git a/eval_protocol/mcp/execution/manager.py b/eval_protocol/mcp/execution/manager.py index 6b1163e9..8e461b49 100644 --- a/eval_protocol/mcp/execution/manager.py +++ b/eval_protocol/mcp/execution/manager.py @@ -20,7 +20,7 @@ from vendor.tau2.data_model.message import AssistantMessage, UserMessage from vendor.tau2.user.user_simulator import UserSimulator -from ...models import EvaluationRow, InputMetadata, Message, RolloutStatus +from ...models import EvaluationRow, InputMetadata, Message, Status from ...types import TerminationReason, Trajectory, NonSkippableException if TYPE_CHECKING: @@ -136,15 +136,14 @@ async def _execute_with_semaphore(idx): } if trajectory.terminated: - evaluation_row.rollout_status.termination_reason = trajectory.termination_reason - evaluation_row.rollout_status.status = RolloutStatus.Status.FINISHED - # preserve the true error mesage if there are any + extra_info = None if trajectory.control_plane_summary.get("error_message"): - evaluation_row.rollout_status.extra_info = { - "error_message": trajectory.control_plane_summary.get("error_message") - } + extra_info = {"error_message": trajectory.control_plane_summary.get("error_message")} + evaluation_row.rollout_status = Status.rollout_finished( + termination_reason=trajectory.termination_reason, extra_info=extra_info + ) else: - evaluation_row.rollout_status.status = RolloutStatus.Status.RUNNING + evaluation_row.rollout_status = Status.rollout_running() return evaluation_row diff --git a/eval_protocol/models.py b/eval_protocol/models.py index 83a0f178..61b96e54 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -1,7 +1,7 @@ import os from datetime import datetime from enum import Enum -from typing import Any, Dict, List, Literal, Optional, TypedDict, Union +from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict, Union from openai.types import CompletionUsage from openai.types.chat.chat_completion_message import ( @@ -15,6 +15,188 @@ from eval_protocol.types import TerminationReason +class ErrorInfo(BaseModel): + """ + AIP-193 ErrorInfo model for structured error details. + + This model follows Google's AIP-193 standard for ErrorInfo: + https://google.aip.dev/193#errorinfo + + Attributes: + reason (str): A short snake_case description of the cause of the error. + domain (str): The logical grouping to which the reason belongs. + metadata (Dict[str, Any]): Additional dynamic information as context. + """ + + # Constants for reason values + REASON_TERMINATION_REASON: ClassVar[str] = "TERMINATION_REASON" + REASON_EXTRA_INFO: ClassVar[str] = "EXTRA_INFO" + + # Domain constant + DOMAIN: ClassVar[str] = "evalprotocol.io" + + reason: str = Field(..., description="Short snake_case description of the error cause") + domain: str = Field(..., description="Logical grouping for the error reason") + metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional dynamic information as context") + + def to_aip193_format(self) -> Dict[str, Any]: + """Convert to AIP-193 format with @type field.""" + return { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": self.reason, + "domain": self.domain, + "metadata": self.metadata, + } + + @classmethod + def termination_reason(cls, reason: TerminationReason) -> "ErrorInfo": + """Create an ErrorInfo for termination reason.""" + # Convert TerminationReason enum to string if needed + reason_str = reason.value if isinstance(reason, TerminationReason) else reason + return cls( + reason=cls.REASON_TERMINATION_REASON, domain=cls.DOMAIN, metadata={"termination_reason": reason_str} + ) + + @classmethod + def extra_info(cls, metadata: Dict[str, Any]) -> "ErrorInfo": + """Create an ErrorInfo for extra information.""" + return cls(reason=cls.REASON_EXTRA_INFO, domain=cls.DOMAIN, metadata=metadata) + + +class Status(BaseModel): + """ + AIP-193 compatible Status model for standardized error responses. + + This model follows Google's AIP-193 standard for error handling: + https://google.aip.dev/193 + + Attributes: + code (int): The status code, must be the numeric value of one of the elements + of google.rpc.Code enum (e.g., 5 for NOT_FOUND). + message (str): Developer-facing, human-readable debug message in English. + details (List[Dict[str, Any]]): Additional error information, each packed in + a google.protobuf.Any message format. + """ + + code: "Status.Code" = Field(..., description="The status code from google.rpc.Code enum") + message: str = Field(..., description="Developer-facing, human-readable debug message in English") + details: List[Dict[str, Any]] = Field( + default_factory=list, + description="Additional error information, each packed in a google.protobuf.Any message format", + ) + + # Convenience constants for common status codes + class Code(int, Enum): + """Common gRPC status codes as defined in google.rpc.Code""" + + OK = 0 + CANCELLED = 1 + UNKNOWN = 2 + INVALID_ARGUMENT = 3 + DEADLINE_EXCEEDED = 4 + NOT_FOUND = 5 + ALREADY_EXISTS = 6 + PERMISSION_DENIED = 7 + RESOURCE_EXHAUSTED = 8 + FAILED_PRECONDITION = 9 + ABORTED = 10 + OUT_OF_RANGE = 11 + UNIMPLEMENTED = 12 + INTERNAL = 13 + UNAVAILABLE = 14 + DATA_LOSS = 15 + UNAUTHENTICATED = 16 + + # Custom codes for rollout states (using higher numbers to avoid conflicts) + FINISHED = 100 + + @classmethod + def rollout_running(cls) -> "Status": + """Create a status indicating the rollout is running.""" + return cls(code=cls.Code.OK, message="Rollout is running", details=[]) + + @classmethod + def rollout_finished( + cls, + termination_reason: Optional[TerminationReason] = None, + extra_info: Optional[Dict[str, Any]] = None, + ) -> "Status": + """Create a status indicating the rollout finished.""" + details = [] + if termination_reason: + details.append(ErrorInfo.termination_reason(termination_reason).to_aip193_format()) + if extra_info: + details.append(ErrorInfo.extra_info(extra_info).to_aip193_format()) + return cls(code=cls.Code.FINISHED, message="Rollout finished", details=details) + + @classmethod + def rollout_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": + """Create a status indicating the rollout failed with an error.""" + details = [] + if extra_info: + details.append(ErrorInfo.extra_info(extra_info).to_aip193_format()) + return cls.error(error_message, details) + + @classmethod + def error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating the rollout failed with an error.""" + return cls(code=cls.Code.INTERNAL, message=error_message, details=details) + + def is_running(self) -> bool: + """Check if the status indicates the rollout is running.""" + return self.code == self.Code.OK and self.message == "Rollout is running" + + def is_finished(self) -> bool: + """Check if the status indicates the rollout finished successfully.""" + return self.code == self.Code.FINISHED + + def is_error(self) -> bool: + """Check if the status indicates the rollout failed with an error.""" + return self.code == self.Code.INTERNAL + + def is_stopped(self) -> bool: + """Check if the status indicates the rollout was stopped.""" + return self.code == self.Code.CANCELLED + + def get_termination_reason(self) -> Optional[TerminationReason]: + """Extract termination reason from details if present.""" + for detail in self.details: + metadata = detail.get("metadata", {}) + if detail.get("reason") == ErrorInfo.REASON_TERMINATION_REASON and "termination_reason" in metadata: + try: + return TerminationReason.from_str(metadata["termination_reason"]) + except ValueError: + # If the reason is not a valid enum value, return None + return None + return None + + def get_extra_info(self) -> Optional[Dict[str, Any]]: + """Extract extra info from details if present.""" + for detail in self.details: + metadata = detail.get("metadata", {}) + reason = detail.get("reason") + # Skip termination_reason and stopped details, return other error info + if reason in [ErrorInfo.REASON_EXTRA_INFO]: + return metadata + return None + + def __hash__(self) -> int: + """Generate a hash for the Status object.""" + # Use a stable hash based on code, message, and details + import hashlib + + # Create a stable string representation + hash_data = f"{self.code}:{self.message}:{len(self.details)}" + + # Add details content for more uniqueness + for detail in sorted(self.details, key=lambda x: str(x)): + hash_data += f":{str(detail)}" + + # Generate hash + hash_obj = hashlib.sha256(hash_data.encode("utf-8")) + return int.from_bytes(hash_obj.digest()[:8], byteorder="big") + + class ChatCompletionContentPartTextParam(BaseModel): text: str = Field(..., description="The text content.") type: Literal["text"] = Field("text", description="The type of the content part.") @@ -289,27 +471,6 @@ class ExecutionMetadata(BaseModel): ) -class RolloutStatus(BaseModel): - """Status of the rollout.""" - - """ - running: Unfinished rollout which is still in progress. - finished: Rollout finished. - error: Rollout failed due to unexpected error. The rollout record should be discard. - """ - - class Status(str, Enum): - RUNNING = "running" - FINISHED = "finished" - ERROR = "error" - - status: Status = Field(Status.RUNNING, description="Status of the rollout.") - termination_reason: Optional[TerminationReason] = Field( - None, description="reason of the rollout status, mapped to values in TerminationReason" - ) - extra_info: Optional[Dict[str, Any]] = Field(None, description="Extra information about the rollout status.") - - class EvaluationRow(BaseModel): """ Unified data structure for a single evaluation unit that contains messages, @@ -334,9 +495,9 @@ class EvaluationRow(BaseModel): description="Metadata related to the input (dataset info, model config, session data, etc.).", ) - rollout_status: RolloutStatus = Field( - default_factory=RolloutStatus, - description="The status of the rollout.", + rollout_status: Status = Field( + default_factory=Status.rollout_running, + description="The status of the rollout following AIP-193 standards.", ) # Ground truth reference (moved from EvaluateResult to top level) diff --git a/eval_protocol/pytest/plugin.py b/eval_protocol/pytest/plugin.py index 81b36420..460eeb14 100644 --- a/eval_protocol/pytest/plugin.py +++ b/eval_protocol/pytest/plugin.py @@ -64,7 +64,7 @@ def pytest_addoption(parser) -> None: action="store", type=int, default=0, - help=("Failed rollouts (with rollout_status.status == 'error') will be retried up to this many times."), + help=("Failed rollouts (with rollout_status.code indicating error) will be retried up to this many times."), ) group.addoption( "--ep-fail-on-max-retry", diff --git a/eval_protocol/pytest/utils.py b/eval_protocol/pytest/utils.py index 429c3328..6b709782 100644 --- a/eval_protocol/pytest/utils.py +++ b/eval_protocol/pytest/utils.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Literal, Optional, Union from eval_protocol.dataset_logger.dataset_logger import DatasetLogger -from eval_protocol.models import EvalMetadata, EvaluationRow, RolloutStatus +from eval_protocol.models import EvalMetadata, EvaluationRow, Status from eval_protocol.pytest.rollout_processor import RolloutProcessor from eval_protocol.pytest.types import ( CompletionParams, @@ -282,7 +282,7 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev try: # Try original task first result = await task - result.rollout_status.status = RolloutStatus.Status.FINISHED + result.rollout_status = Status.rollout_finished() return result except Exception as e: # NOTE: we perform these checks because we don't put the backoff decorator on initial batch call. we don't want to retry whole batch if anything fails. @@ -295,17 +295,15 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev # Use shared backoff function for retryable exceptions try: result = await execute_row_with_backoff_retry(row) - result.rollout_status.status = RolloutStatus.Status.FINISHED + result.rollout_status = Status.rollout_finished() return result except Exception as retry_error: # Backoff gave up - row.rollout_status.status = RolloutStatus.Status.ERROR - # row.rollout_status.termination_reason = str(retry_error) + row.rollout_status = Status.rollout_error(str(retry_error)) return row else: # Non-retryable exception - fail immediately - row.rollout_status.status = RolloutStatus.Status.ERROR - # row.rollout_status.termination_reason = str(e) + row.rollout_status = Status.rollout_error(str(e)) return row # Process all tasks concurrently with backoff retry diff --git a/tests/test_retry_mechanism.py b/tests/test_retry_mechanism.py index 6895ad55..863d7838 100644 --- a/tests/test_retry_mechanism.py +++ b/tests/test_retry_mechanism.py @@ -11,7 +11,7 @@ import pytest -from eval_protocol.models import EvaluateResult, EvaluationRow, Message, RolloutStatus +from eval_protocol.models import EvaluateResult, EvaluationRow, Message, Status from eval_protocol.pytest.evaluation_test import evaluation_test from eval_protocol.pytest.rollout_processor import RolloutProcessor from eval_protocol.pytest.types import RolloutProcessorConfig @@ -95,11 +95,11 @@ async def process_single_row( def test_retry_mechanism(row: EvaluationRow) -> EvaluationRow: """MOCK TEST: Tests that retry mechanism works - one task fails on first attempt, succeeds on retry.""" print( - f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})" + f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})" ) # Assign a score based on success/failure - score = 1.0 if row.rollout_status.status == "finished" else 0.0 + score = 1.0 if row.rollout_status.is_finished() else 0.0 row.evaluation_result = EvaluateResult(score=score) return row @@ -191,9 +191,9 @@ async def process_single_row(row: EvaluationRow) -> EvaluationRow: def test_fail_fast_exceptions(row: EvaluationRow) -> EvaluationRow: """Test that fail-fast exceptions like ValueError are not retried.""" print( - f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})" + f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})" ) - score = 1.0 if row.rollout_status.status == "finished" else 0.0 + score = 1.0 if row.rollout_status.is_finished() else 0.0 row.evaluation_result = EvaluateResult(score=score) return row @@ -283,8 +283,8 @@ def custom_http_giveup(e): def test_custom_giveup_function(row: EvaluationRow) -> EvaluationRow: """Test custom giveup function behavior.""" task_content = row.messages[0].content if row.messages else "" - print(f"📊 EVALUATED: {task_content} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})") - score = 1.0 if row.rollout_status.status == "finished" else 0.0 + print(f"📊 EVALUATED: {task_content} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})") + score = 1.0 if row.rollout_status.is_finished() else 0.0 row.evaluation_result = EvaluateResult(score=score) return row @@ -368,9 +368,9 @@ def simple_4xx_giveup(e): def test_simple_giveup_function(row: EvaluationRow) -> EvaluationRow: """Test that giveup function prevents retries immediately.""" print( - f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})" + f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})" ) - score = 1.0 if row.rollout_status.status == "finished" else 0.0 + score = 1.0 if row.rollout_status.is_finished() else 0.0 row.evaluation_result = EvaluateResult(score=score) return row diff --git a/tests/test_status_migration_changes.py b/tests/test_status_migration_changes.py new file mode 100644 index 00000000..66bde969 --- /dev/null +++ b/tests/test_status_migration_changes.py @@ -0,0 +1,440 @@ +#!/usr/bin/env python3 +""" +Tests for the migration changes in the existing codebase. + +This test suite verifies that: +- All migrated code works correctly with the new Status model +- The field name remains as 'rollout_status' +- All helper methods work as expected +- AIP-193 compliance is maintained +""" + +import pytest +from unittest.mock import Mock, patch, MagicMock +from eval_protocol.models import Status, EvaluationRow, Message +from eval_protocol.types import TerminationReason + + +class TestMCPExecutionManagerMigration: + """Test the migration changes in MCP execution manager.""" + + def test_trajectory_terminated_status_creation(self): + """Test that terminated trajectory creates correct status.""" + # Mock trajectory with termination + trajectory = Mock() + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.CONTROL_PLANE_SIGNAL + trajectory.control_plane_summary = {"error_message": "No errors"} + + # Create evaluation row + row = EvaluationRow(messages=[]) + + # Simulate the status assignment from MCP execution manager + extra_info = {"error_message": trajectory.control_plane_summary.get("error_message")} + + row.rollout_status = Status( + code=Status.Code.FINISHED, + message="Rollout finished", + details=[ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "TERMINATION_REASON", + "domain": "evalprotocol.io", + "metadata": {"termination_reason": trajectory.termination_reason}, + } + ] + + ( + [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "EXTRA_INFO", + "domain": "evalprotocol.io", + "metadata": extra_info, + } + ] + if extra_info + else [] + ), + ) + + # Verify the status + assert row.rollout_status.code == Status.Code.FINISHED + assert row.rollout_status.message == "Rollout finished" + assert row.rollout_status.is_finished() + + # Verify termination reason + assert row.rollout_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL + + # Verify extra info + assert row.rollout_status.get_extra_info() == {"error_message": "No errors"} + + # Verify details structure + assert len(row.rollout_status.details) == 2 + assert row.rollout_status.details[0]["reason"] == "TERMINATION_REASON" + assert row.rollout_status.details[1]["reason"] == "EXTRA_INFO" + + def test_trajectory_running_status_creation(self): + """Test that running trajectory creates correct status.""" + # Mock trajectory that's still running + trajectory = Mock() + trajectory.terminated = False + + # Create evaluation row + row = EvaluationRow(messages=[]) + + # Simulate the status assignment from MCP execution manager + row.rollout_status = Status(code=Status.Code.OK, message="Rollout is running", details=[]) + + # Verify the status + assert row.rollout_status.code == Status.Code.OK + assert row.rollout_status.message == "Rollout is running" + assert row.rollout_status.is_running() + assert not row.rollout_status.is_finished() + assert not row.rollout_status.is_error() + assert not row.rollout_status.is_stopped() + + def test_trajectory_terminated_without_error_message(self): + """Test terminated trajectory without error message.""" + # Mock trajectory with termination but no error + trajectory = Mock() + trajectory.terminated = True + trajectory.termination_reason = TerminationReason.USER_STOP + trajectory.control_plane_summary = {} + + # Create evaluation row + row = EvaluationRow(messages=[]) + + # Simulate the status assignment + extra_info = None + if trajectory.control_plane_summary.get("error_message"): + extra_info = {"error_message": trajectory.control_plane_summary.get("error_message")} + + row.rollout_status = Status( + code=Status.Code.FINISHED, + message="Rollout finished", + details=[ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "TERMINATION_REASON", + "domain": "evalprotocol.io", + "metadata": {"termination_reason": trajectory.termination_reason}, + } + ] + + ( + [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "EXTRA_INFO", + "domain": "evalprotocol.io", + "metadata": extra_info, + } + ] + if extra_info + else [] + ), + ) + + # Verify the status + assert row.rollout_status.code == Status.Code.FINISHED + assert row.rollout_status.is_finished() + assert row.rollout_status.get_termination_reason() == TerminationReason.USER_STOP + + # Should not have extra info since there was no error message + assert row.rollout_status.get_extra_info() is None + + # Should only have termination reason detail + assert len(row.rollout_status.details) == 1 + assert row.rollout_status.details[0]["reason"] == "TERMINATION_REASON" + + +class TestPytestUtilsMigration: + """Test the migration changes in pytest utils.""" + + def test_retry_success_status_update(self): + """Test that retry success updates status correctly.""" + row = EvaluationRow(messages=[]) + + # Simulate the status update from pytest utils + row.rollout_status = Status(code=Status.Code.FINISHED, message="Rollout finished successfully", details=[]) + + # Verify the status + assert row.rollout_status.code == Status.Code.FINISHED + assert row.rollout_status.message == "Rollout finished successfully" + assert row.rollout_status.is_finished() + assert not row.rollout_status.is_running() + + def test_retry_failure_status_update(self): + """Test that retry failure updates status correctly.""" + row = EvaluationRow(messages=[]) + + # Simulate the status update from pytest utils + row.rollout_status = Status( + code=Status.Code.INTERNAL, + message="Test error message", + details=[ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "ROLLOUT_ERROR", + "domain": "evalprotocol.io", + "metadata": {}, + } + ], + ) + + # Verify the status + assert row.rollout_status.code == Status.Code.INTERNAL + assert row.rollout_status.message == "Test error message" + assert row.rollout_status.is_error() + assert not row.rollout_status.is_finished() + + def test_initial_processor_success_status_update(self): + """Test that initial processor success updates status correctly.""" + row = EvaluationRow(messages=[]) + + # Simulate the status update from pytest utils + row.rollout_status = Status(code=Status.Code.FINISHED, message="Rollout finished successfully", details=[]) + + # Verify the status + assert row.rollout_status.code == Status.Code.FINISHED + assert row.rollout_status.message == "Rollout finished successfully" + assert row.rollout_status.is_finished() + + def test_initial_processor_failure_status_update(self): + """Test that initial processor failure updates status correctly.""" + row = EvaluationRow(messages=[]) + + # Simulate the status update from pytest utils + row.rollout_status = Status( + code=Status.Code.INTERNAL, + message="Runtime error occurred", + details=[ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "ROLLOUT_ERROR", + "domain": "evalprotocol.io", + "metadata": {}, + } + ], + ) + + # Verify the status + assert row.rollout_status.code == Status.Code.INTERNAL + assert row.rollout_status.message == "Runtime error occurred" + assert row.rollout_status.is_error() + + def test_error_status_checking(self): + """Test that error status checking works correctly.""" + row = EvaluationRow(messages=[]) + + # Set error status + row.rollout_status = Status.rollout_error("Test error") + + # Should be detected as error + assert row.rollout_status.is_error() + + # Should be able to get termination reason (None for error status) + assert row.rollout_status.get_termination_reason() is None + + +class TestTestRetryMechanismMigration: + """Test the migration changes in test_retry_mechanism.py.""" + + def test_success_failure_detection(self): + """Test that success/failure detection works correctly.""" + row = EvaluationRow(messages=[]) + + # Test success case + row.rollout_status = Status.rollout_finished() + success = row.rollout_status.is_finished() + assert success is True + + # Test failure case + row.rollout_status = Status.rollout_error("Test error") + success = row.rollout_status.is_finished() + assert success is False + + def test_score_assignment_based_on_status(self): + """Test that score assignment works based on status.""" + row = EvaluationRow(messages=[]) + + # Test success score + row.rollout_status = Status.rollout_finished() + score = 1.0 if row.rollout_status.is_finished() else 0.0 + assert score == 1.0 + + # Test failure score + row.rollout_status = Status.rollout_error("Test error") + score = 1.0 if row.rollout_status.is_finished() else 0.0 + assert score == 0.0 + + +class TestStatusModelIntegration: + """Test integration of Status model with existing functionality.""" + + def test_status_creation_methods_integration(self): + """Test that all status creation methods work together.""" + row = EvaluationRow(messages=[]) + + # Test running status + row.rollout_status = Status.rollout_running() + assert row.rollout_status.is_running() + assert row.rollout_status.code == Status.Code.OK + + # Test finished status + row.rollout_status = Status.rollout_finished() + assert row.rollout_status.is_finished() + assert row.rollout_status.code == Status.Code.FINISHED + + # Test error status + row.rollout_status = Status.rollout_error("Test error") + assert row.rollout_status.is_error() + assert row.rollout_status.code == Status.Code.INTERNAL + + def test_termination_reason_integration(self): + """Test integration of termination reason with status.""" + row = EvaluationRow(messages=[]) + + # Test with termination reason + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL) + row.rollout_status = termination_status + + assert row.rollout_status.is_finished() + assert row.rollout_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL + + # Test with termination reason and extra info + extra_info = {"steps": 10, "reward": 0.8} + termination_status_with_info = Status.rollout_finished(TerminationReason.USER_STOP, extra_info) + row.rollout_status = termination_status_with_info + + assert row.rollout_status.is_finished() + assert row.rollout_status.get_termination_reason() == TerminationReason.USER_STOP + assert row.rollout_status.get_extra_info() == extra_info + + def test_error_handling_integration(self): + """Test integration of error handling with status.""" + row = EvaluationRow(messages=[]) + + # Test error with metadata + error_info = {"error_code": "E001", "line": 42} + error_status = Status.rollout_error("Runtime error", error_info) + row.rollout_status = error_status + + assert row.rollout_status.is_error() + assert row.rollout_status.get_extra_info() == error_info + assert row.rollout_status.get_termination_reason() is None + + # Test error without metadata + simple_error_status = Status.rollout_error("Simple error") + row.rollout_status = simple_error_status + + assert row.rollout_status.is_error() + assert row.rollout_status.get_extra_info() is None + + def test_status_transitions_integration(self): + """Test that status transitions work correctly in integration.""" + row = EvaluationRow(messages=[]) + + # Start with running + row.rollout_status = Status.rollout_running() + assert row.rollout_status.is_running() + + # Transition to finished + row.rollout_status = Status.rollout_finished() + assert row.rollout_status.is_finished() + assert not row.rollout_status.is_running() + + # Transition to error + row.rollout_status = Status.rollout_error("Something went wrong") + assert row.rollout_status.is_error() + assert not row.rollout_status.is_finished() + + # Transition back to finished + row.rollout_status = Status.rollout_finished() + assert row.rollout_status.is_finished() + assert not row.rollout_status.is_error() + + +class TestAIP193Compliance: + """Test AIP-193 compliance in the migrated code.""" + + def test_error_info_structure_compliance(self): + """Test that ErrorInfo structure follows AIP-193.""" + row = EvaluationRow(messages=[]) + + # Create error status with metadata + error_info = {"error_code": "E001", "timestamp": "2024-01-01"} + error_status = Status.rollout_error("Test error", error_info) + row.rollout_status = error_status + + # Check AIP-193 ErrorInfo structure + assert len(row.rollout_status.details) == 1 + detail = row.rollout_status.details[0] + + # Required fields according to AIP-193 + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert "reason" in detail + assert "domain" in detail + assert "metadata" in detail + + # Domain should be service-specific + assert detail["domain"] == "evalprotocol.io" + + # Metadata should contain the error info + assert detail["metadata"] == error_info + + def test_termination_reason_structure_compliance(self): + """Test that termination reason structure follows AIP-193.""" + row = EvaluationRow(messages=[]) + + # Create status with termination reason + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL) + row.rollout_status = termination_status + + # Check AIP-193 structure + assert len(row.rollout_status.details) == 1 + detail = row.rollout_status.details[0] + + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert detail["reason"] == "TERMINATION_REASON" + assert detail["domain"] == "evalprotocol.io" + assert "metadata" in detail + assert "termination_reason" in detail["metadata"] + + def test_multiple_details_compliance(self): + """Test that multiple details follow AIP-193 structure.""" + row = EvaluationRow(messages=[]) + + # Create status with both termination reason and extra info + extra_info = {"steps": 15, "reward": 0.9} + status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) + row.rollout_status = status + + # Should have two details + assert len(row.rollout_status.details) == 2 + + # Both should follow ErrorInfo structure + for detail in row.rollout_status.details: + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert "reason" in detail + assert "domain" in detail + assert "metadata" in detail + assert detail["domain"] == "evalprotocol.io" + + def test_status_code_compliance(self): + """Test that status codes follow gRPC standard.""" + row = EvaluationRow(messages=[]) + + # Test standard gRPC codes + statuses = [ + (Status.rollout_running(), Status.Code.OK), + (Status.rollout_finished(), Status.Code.FINISHED), # Custom code + (Status.rollout_error("Test"), Status.Code.INTERNAL), + ] + + for status, expected_code in statuses: + row.rollout_status = status + assert row.rollout_status.code == expected_code + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_status_migration_integration.py b/tests/test_status_migration_integration.py new file mode 100644 index 00000000..e9beb080 --- /dev/null +++ b/tests/test_status_migration_integration.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python3 +""" +Integration tests for the Status model migration. + +This test suite verifies that: +- All migrated code works correctly with the new Status model +- The field name remains as 'rollout_status' +- All helper methods work as expected +- AIP-193 compliance is maintained +""" + +import pytest +from unittest.mock import Mock, patch +from eval_protocol.models import Status, EvaluationRow, Message +from eval_protocol.types import TerminationReason + + +class TestStatusFieldNamePreservation: + """Test that the field name remains as 'rollout_status'.""" + + def test_evaluation_row_has_rollout_status_field(self): + """Test that EvaluationRow still has rollout_status field.""" + row = EvaluationRow(messages=[]) + + # Should have rollout_status field + assert hasattr(row, "rollout_status") + assert not hasattr(row, "status") + + # Field should be of type Status + assert isinstance(row.rollout_status, Status) + + def test_rollout_status_field_access(self): + """Test direct access to rollout_status field.""" + row = EvaluationRow(messages=[]) + + # Should be able to access directly + assert row.rollout_status.code == Status.Code.OK + assert row.rollout_status.message == "Rollout is running" + + # Should be able to set directly + row.rollout_status = Status.rollout_finished() + assert row.rollout_status.code == Status.Code.FINISHED + + +class TestStatusTransitions: + """Test transitioning between different status states.""" + + def test_running_to_finished_transition(self): + """Test transition from running to finished.""" + row = EvaluationRow(messages=[]) + + # Start with running + assert row.rollout_status.is_running() + assert not row.rollout_status.is_finished() + + # Transition to finished + row.rollout_status = Status.rollout_finished() + assert not row.rollout_status.is_running() + assert row.rollout_status.is_finished() + + def test_running_to_error_transition(self): + """Test transition from running to error.""" + row = EvaluationRow(messages=[]) + + # Start with running + assert row.rollout_status.is_running() + assert not row.rollout_status.is_error() + + # Transition to error + row.rollout_status = Status.rollout_error("Something went wrong") + assert not row.rollout_status.is_running() + assert row.rollout_status.is_error() + + def test_error_to_finished_transition(self): + """Test transition from error to finished.""" + row = EvaluationRow(messages=[]) + + # Start with error + row.rollout_status = Status.rollout_error("Initial error") + assert row.rollout_status.is_error() + + # Transition to finished + row.rollout_status = Status.rollout_finished() + assert not row.rollout_status.is_error() + assert row.rollout_status.is_finished() + + +class TestTerminationReasonIntegration: + """Test integration of termination reason with the new Status model.""" + + def test_termination_reason_in_status_details(self): + """Test that termination reason is properly stored in status details.""" + row = EvaluationRow(messages=[]) + + # Set status with termination reason + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL) + row.rollout_status = termination_status + + # Should be finished + assert row.rollout_status.is_finished() + + # Should have termination reason in details + assert row.rollout_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL + + # Check details structure + assert len(row.rollout_status.details) == 1 + detail = row.rollout_status.details[0] + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert detail["reason"] == "TERMINATION_REASON" + assert detail["domain"] == "evalprotocol.io" + assert detail["metadata"]["termination_reason"] == TerminationReason.CONTROL_PLANE_SIGNAL + + def test_termination_reason_with_extra_info(self): + """Test termination reason with additional extra info.""" + row = EvaluationRow(messages=[]) + + extra_info = {"steps": 10, "reward": 0.8} + termination_status = Status.rollout_finished(TerminationReason.USER_STOP, extra_info) + row.rollout_status = termination_status + + # Should have both termination reason and extra info + assert row.rollout_status.get_termination_reason() == TerminationReason.USER_STOP + assert row.rollout_status.get_extra_info() == extra_info + + # Check details structure + assert len(row.rollout_status.details) == 2 + + # First detail should be termination reason + term_detail = row.rollout_status.details[0] + assert term_detail["reason"] == "TERMINATION_REASON" + + # Second detail should be extra info + extra_detail = row.rollout_status.details[1] + assert extra_detail["reason"] == "EXTRA_INFO" + assert extra_detail["metadata"] == extra_info + + def test_multiple_termination_reasons(self): + """Test handling of multiple termination reasons (edge case).""" + row = EvaluationRow(messages=[]) + + # Create status with duplicate termination reason details + details = [ + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "TERMINATION_REASON", + "domain": "evalprotocol.io", + "metadata": {"termination_reason": TerminationReason.USER_STOP}, + }, + { + "@type": "type.googleapis.com/google.rpc.ErrorInfo", + "reason": "TERMINATION_REASON", + "domain": "evalprotocol.io", + "metadata": {"termination_reason": TerminationReason.SKIPPABLE_ERROR}, + }, + ] + + status = Status(code=Status.Code.FINISHED, message="Test", details=details) + row.rollout_status = status + + # Should return the first termination reason found + assert row.rollout_status.get_termination_reason() == TerminationReason.USER_STOP + + +class TestErrorHandlingIntegration: + """Test error handling integration with the new Status model.""" + + def test_error_status_with_metadata(self): + """Test error status with structured metadata.""" + row = EvaluationRow(messages=[]) + + error_info = { + "error_code": "E001", + "line": 42, + "function": "test_function", + "timestamp": "2024-01-01T12:00:00Z", + } + + error_status = Status.rollout_error("Runtime error occurred", error_info) + row.rollout_status = error_status + + # Should be error + assert row.rollout_status.is_error() + + # Should have error details + assert row.rollout_status.get_extra_info() == error_info + + # Should not have termination reason + assert row.rollout_status.get_termination_reason() is None + + # Check details structure + assert len(row.rollout_status.details) == 1 + detail = row.rollout_status.details[0] + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert detail["reason"] == "EXTRA_INFO" + assert detail["domain"] == "evalprotocol.io" + assert detail["metadata"] == error_info + + def test_error_status_without_metadata(self): + """Test error status without additional metadata.""" + row = EvaluationRow(messages=[]) + + error_status = Status.rollout_error("Simple error message") + row.rollout_status = error_status + + # Should be error + assert row.rollout_status.is_error() + + # Should not have extra info + assert row.rollout_status.get_extra_info() is None + + # Should not have termination reason + assert row.rollout_status.get_termination_reason() is None + + # Should have empty details + assert row.rollout_status.details == [] + + +class TestAIP193Compliance: + """Test AIP-193 compliance features.""" + + def test_error_info_structure(self): + """Test that ErrorInfo follows AIP-193 structure.""" + row = EvaluationRow(messages=[]) + + # Create status with error info + error_info = {"error_code": "E001"} + error_status = Status.rollout_error("Test error", error_info) + row.rollout_status = error_status + + # Check AIP-193 ErrorInfo structure + assert len(row.rollout_status.details) == 1 + detail = row.rollout_status.details[0] + + # Required fields + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert "reason" in detail + assert "domain" in detail + assert "metadata" in detail + + # Domain should be service-specific + assert detail["domain"] == "evalprotocol.io" + + # Metadata should contain the error info + assert detail["metadata"] == error_info + + def test_multiple_detail_types(self): + """Test that multiple detail types can coexist.""" + row = EvaluationRow(messages=[]) + + # Create status with both termination reason and extra info + extra_info = {"steps": 15, "reward": 0.9} + status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) + row.rollout_status = status + + # Should have two details + assert len(row.rollout_status.details) == 2 + + # First detail should be termination reason + term_detail = row.rollout_status.details[0] + assert term_detail["reason"] == "TERMINATION_REASON" + + # Second detail should be extra info + extra_detail = row.rollout_status.details[1] + assert extra_detail["reason"] == "EXTRA_INFO" + + # Both should follow ErrorInfo structure + for detail in row.rollout_status.details: + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert "reason" in detail + assert "domain" in detail + assert "metadata" in detail + + def test_status_code_mapping(self): + """Test that status codes map correctly to gRPC codes.""" + row = EvaluationRow(messages=[]) + + # Test different status types and their codes + statuses = [ + (Status.rollout_running(), Status.Code.OK), + (Status.rollout_finished(), Status.Code.FINISHED), + (Status.rollout_error("Test"), Status.Code.INTERNAL), + ] + + for status, expected_code in statuses: + row.rollout_status = status + assert row.rollout_status.code == expected_code + + +class TestSerializationAndDeserialization: + """Test that Status can be properly serialized and deserialized.""" + + def test_status_model_dump(self): + """Test that Status can be dumped to dict.""" + row = EvaluationRow(messages=[]) + + # Set a complex status + extra_info = {"steps": 10, "reward": 0.8} + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) + row.rollout_status = termination_status + + # Dump to dict + status_dict = row.rollout_status.model_dump() + + # Check structure + assert "code" in status_dict + assert "message" in status_dict + assert "details" in status_dict + + # Check values + assert status_dict["code"] == Status.Code.FINISHED + assert status_dict["message"] == "Rollout finished" + assert len(status_dict["details"]) == 2 + + def test_status_model_validate(self): + """Test that Status can be reconstructed from dict.""" + row = EvaluationRow(messages=[]) + + # Set a complex status + extra_info = {"steps": 10, "reward": 0.8} + original_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) + row.rollout_status = original_status + + # Dump and reconstruct + status_dict = row.rollout_status.model_dump() + reconstructed_status = Status.model_validate(status_dict) + + # Should be equivalent + assert reconstructed_status.code == original_status.code + assert reconstructed_status.message == original_status.message + assert len(reconstructed_status.details) == len(original_status.details) + + # Should preserve functionality + assert reconstructed_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL + assert reconstructed_status.get_extra_info() == extra_info + + +class TestEdgeCases: + """Test edge cases and error conditions.""" + + def test_empty_status_details(self): + """Test Status with empty details.""" + row = EvaluationRow(messages=[]) + + # Create status with empty details + empty_status = Status(code=Status.Code.OK, message="Test", details=[]) + row.rollout_status = empty_status + + # Should handle gracefully + assert row.rollout_status.get_termination_reason() is None + assert row.rollout_status.get_extra_info() is None + + def test_malformed_status_details(self): + """Test Status with malformed details.""" + row = EvaluationRow(messages=[]) + + # Create status with malformed details + malformed_details = [ + {"not_type": "invalid", "reason": "TEST"}, + {"@type": "type.googleapis.com/google.rpc.ErrorInfo", "metadata": {"termination_reason": "test"}}, + ] + + malformed_status = Status(code=Status.Code.OK, message="Test", details=malformed_details) + row.rollout_status = malformed_status + + # Should handle gracefully + assert row.rollout_status.get_termination_reason() is None + assert row.rollout_status.get_extra_info() is None + + def test_large_metadata_handling(self): + """Test Status with large metadata.""" + row = EvaluationRow(messages=[]) + + # Create large metadata + large_metadata = {f"key_{i}": f"value_{i}" for i in range(100)} + + # Should handle large metadata + large_status = Status.rollout_error("Test error", large_metadata) + row.rollout_status = large_status + + # Should preserve all metadata + extra_info = row.rollout_status.get_extra_info() + assert extra_info == large_metadata + assert extra_info is not None + assert len(extra_info) == 100 + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/test_status_model.py b/tests/test_status_model.py new file mode 100644 index 00000000..459e8497 --- /dev/null +++ b/tests/test_status_model.py @@ -0,0 +1,350 @@ +#!/usr/bin/env python3 +""" +Tests for the AIP-193 compatible Status model. + +This test suite covers: +- Status code enum values +- Status creation methods +- Helper methods for checking status types +- AIP-193 compliance features +- Migration from RolloutStatus +""" + +import pytest +from eval_protocol.models import Status, EvaluationRow, ErrorInfo +from eval_protocol.types import TerminationReason + + +class TestErrorInfoModel: + """Test the ErrorInfo model.""" + + def test_error_info_creation(self): + """Test creating ErrorInfo instances.""" + error_info = ErrorInfo( + reason="TEST_ERROR", domain="evalprotocol.io", metadata={"error_code": "E001", "line": 42} + ) + + assert error_info.reason == "TEST_ERROR" + assert error_info.domain == "evalprotocol.io" + assert error_info.metadata == {"error_code": "E001", "line": 42} + + def test_error_info_to_aip193_format(self): + """Test conversion to AIP-193 format.""" + error_info = ErrorInfo(reason="TEST_ERROR", domain="evalprotocol.io", metadata={"error_code": "E001"}) + + aip193_format = error_info.to_aip193_format() + + assert aip193_format["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert aip193_format["reason"] == "TEST_ERROR" + assert aip193_format["domain"] == "evalprotocol.io" + assert aip193_format["metadata"] == {"error_code": "E001"} + + def test_error_info_factory_methods(self): + """Test the factory methods for common error types.""" + # Test termination_reason + term_error = ErrorInfo.termination_reason(TerminationReason.CONTROL_PLANE_SIGNAL) + assert term_error.reason == "TERMINATION_REASON" + assert term_error.domain == "evalprotocol.io" + assert term_error.metadata["termination_reason"] == TerminationReason.CONTROL_PLANE_SIGNAL + + # Test extra_info + extra_error = ErrorInfo.extra_info({"steps": 10, "reward": 0.8}) + assert extra_error.reason == "EXTRA_INFO" + assert extra_error.domain == "evalprotocol.io" + assert extra_error.metadata == {"steps": 10, "reward": 0.8} + + +class TestStatusModel: + """Test the AIP-193 compatible Status model.""" + + def test_status_code_enum_values(self): + """Test that Status.Code enum has the correct values.""" + assert Status.Code.OK == 0 + assert Status.Code.CANCELLED == 1 + assert Status.Code.UNKNOWN == 2 + assert Status.Code.INVALID_ARGUMENT == 3 + assert Status.Code.DEADLINE_EXCEEDED == 4 + assert Status.Code.NOT_FOUND == 5 + assert Status.Code.ALREADY_EXISTS == 6 + assert Status.Code.PERMISSION_DENIED == 7 + assert Status.Code.RESOURCE_EXHAUSTED == 8 + assert Status.Code.FAILED_PRECONDITION == 9 + assert Status.Code.ABORTED == 10 + assert Status.Code.OUT_OF_RANGE == 11 + assert Status.Code.UNIMPLEMENTED == 12 + assert Status.Code.INTERNAL == 13 + assert Status.Code.UNAVAILABLE == 14 + assert Status.Code.DATA_LOSS == 15 + assert Status.Code.UNAUTHENTICATED == 16 + assert Status.Code.FINISHED == 100 # Custom code + + def test_status_creation_methods(self): + """Test the convenience methods for creating Status instances.""" + # Test running status + running_status = Status.rollout_running() + assert running_status.code == Status.Code.OK + assert running_status.message == "Rollout is running" + assert running_status.details == [] + + # Test finished status + finished_status = Status.rollout_finished() + assert finished_status.code == Status.Code.FINISHED + assert finished_status.message == "Rollout finished" + assert finished_status.details == [] + + # Test error status + error_status = Status.rollout_error("Something went wrong") + assert error_status.code == Status.Code.INTERNAL + assert error_status.message == "Something went wrong" + assert error_status.details == [] + + # Test error status with extra info + extra_info = {"error_code": "E001", "timestamp": "2024-01-01"} + error_status_with_info = Status.rollout_error("Something went wrong", extra_info) + assert error_status_with_info.code == Status.Code.INTERNAL + assert error_status_with_info.message == "Something went wrong" + assert len(error_status_with_info.details) == 1 + assert error_status_with_info.details[0]["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert error_status_with_info.details[0]["reason"] == "EXTRA_INFO" + assert error_status_with_info.details[0]["domain"] == "evalprotocol.io" + assert error_status_with_info.details[0]["metadata"] == extra_info + + # Test with termination reason + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL) + assert termination_status.code == Status.Code.FINISHED + assert termination_status.message == "Rollout finished" + assert len(termination_status.details) == 1 + assert termination_status.details[0]["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert termination_status.details[0]["reason"] == "TERMINATION_REASON" + assert termination_status.details[0]["domain"] == "evalprotocol.io" + assert ( + termination_status.details[0]["metadata"]["termination_reason"] == TerminationReason.CONTROL_PLANE_SIGNAL + ) + + # Test with termination reason and extra info + extra_info = {"steps": 10, "reward": 0.8} + termination_status_with_info = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) + assert termination_status_with_info.code == Status.Code.FINISHED + assert len(termination_status_with_info.details) == 2 + # First detail should be termination reason + assert termination_status_with_info.details[0]["reason"] == "TERMINATION_REASON" + # Second detail should be extra info + assert termination_status_with_info.details[1]["reason"] == "EXTRA_INFO" + assert termination_status_with_info.details[1]["metadata"] == extra_info + + def test_status_helper_methods(self): + """Test the helper methods for checking status types.""" + # Test is_running + running_status = Status.rollout_running() + assert running_status.is_running() is True + assert running_status.is_finished() is False + assert running_status.is_error() is False + assert running_status.is_stopped() is False + + # Test is_finished + finished_status = Status.rollout_finished() + assert finished_status.is_running() is False + assert finished_status.is_finished() is True + assert finished_status.is_error() is False + assert finished_status.is_stopped() is False + + # Test is_error + error_status = Status.rollout_error("Test error") + assert error_status.is_running() is False + assert error_status.is_finished() is False + assert error_status.is_error() is True + assert error_status.is_stopped() is False + + def test_get_termination_reason(self): + """Test extracting termination reason from status details.""" + + # Status without termination reason + running_status = Status.rollout_running() + assert running_status.get_termination_reason() is None + + # Status with termination reason + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL) + assert termination_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL + + # Status with termination reason and extra info + extra_info = {"steps": 10} + termination_status_with_info = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) + assert termination_status_with_info.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL + + def test_get_extra_info(self): + """Test extracting extra info from status details.""" + + # Status without extra info + running_status = Status.rollout_running() + assert running_status.get_extra_info() is None + + # Status with only termination reason (no extra info) + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL) + assert termination_status.get_extra_info() is None + + # Status with extra info + extra_info = {"steps": 10, "reward": 0.8} + error_status = Status.rollout_error("Test error", extra_info) + assert error_status.get_extra_info() == extra_info + + # Status with both termination reason and extra info + termination_status_with_info = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) + assert termination_status_with_info.get_extra_info() == extra_info + + def test_aip_193_compliance(self): + """Test that Status model follows AIP-193 standards.""" + # Test ErrorInfo structure + extra_info = {"error_code": "E001"} + error_status = Status.rollout_error("Test error", extra_info) + + assert len(error_status.details) == 1 + detail = error_status.details[0] + + # Check AIP-193 ErrorInfo structure + assert detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert detail["reason"] == "EXTRA_INFO" + assert detail["domain"] == "evalprotocol.io" + assert detail["metadata"] == extra_info + + # Test multiple details + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, extra_info) + assert len(termination_status.details) == 2 + + # First detail should be termination reason + term_detail = termination_status.details[0] + assert term_detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert term_detail["reason"] == "TERMINATION_REASON" + + # Second detail should be extra info + extra_detail = termination_status.details[1] + assert extra_detail["@type"] == "type.googleapis.com/google.rpc.ErrorInfo" + assert extra_detail["reason"] == "EXTRA_INFO" + + def test_status_serialization(self): + """Test that Status can be serialized and deserialized.""" + original_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, {"steps": 10}) + + # Test model_dump + status_dict = original_status.model_dump() + assert status_dict["code"] == Status.Code.FINISHED + assert status_dict["message"] == "Rollout finished" + assert len(status_dict["details"]) == 2 + + # Test model_validate + reconstructed_status = Status.model_validate(status_dict) + assert reconstructed_status.code == original_status.code + assert reconstructed_status.message == original_status.message + assert len(reconstructed_status.details) == len(original_status.details) + assert reconstructed_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL + assert reconstructed_status.get_extra_info() == {"steps": 10} + + def test_status_equality(self): + """Test Status equality and comparison.""" + status1 = Status.rollout_running() + status2 = Status.rollout_running() + status3 = Status.rollout_finished() + + # Same values should be equal + assert status1 == status2 + + # Different values should not be equal + assert status1 != status3 + + # Test hash + assert hash(status1) == hash(status2) + assert hash(status1) != hash(status3) + + +class TestStatusMigration: + """Test the migration from RolloutStatus to Status.""" + + def test_evaluation_row_default_status(self): + """Test that EvaluationRow has the correct default status.""" + row = EvaluationRow(messages=[]) + + # Should have rollout_status field (not status) + assert hasattr(row, "rollout_status") + assert not hasattr(row, "status") + + # Default status should be running + assert row.rollout_status.code == Status.Code.OK + assert row.rollout_status.message == "Rollout is running" + assert row.rollout_status.details == [] + + def test_status_transitions(self): + """Test transitioning between different status states.""" + row = EvaluationRow(messages=[]) + + # Start with running + assert row.rollout_status.is_running() + + # Transition to finished + row.rollout_status = Status.rollout_finished() + assert row.rollout_status.is_finished() + assert not row.rollout_status.is_running() + + # Transition to error + row.rollout_status = Status.rollout_error("Something went wrong") + assert row.rollout_status.is_error() + assert not row.rollout_status.is_finished() + + def test_termination_reason_integration(self): + """Test integration of termination reason with status.""" + row = EvaluationRow(messages=[]) + + # Set status with termination reason + termination_status = Status.rollout_finished(TerminationReason.CONTROL_PLANE_SIGNAL, {"steps": 15}) + row.rollout_status = termination_status + + # Should be finished + assert row.rollout_status.is_finished() + + # Should have termination reason + assert row.rollout_status.get_termination_reason() == TerminationReason.CONTROL_PLANE_SIGNAL + + # Should have extra info + extra_info = row.rollout_status.get_extra_info() + assert extra_info == {"steps": 15} + + def test_error_handling_integration(self): + """Test error handling integration with status.""" + row = EvaluationRow(messages=[]) + + # Set error status + error_info = {"error_code": "E001", "line": 42} + error_status = Status.rollout_error("Runtime error occurred", error_info) + row.rollout_status = error_status + + # Should be error + assert row.rollout_status.is_error() + + # Should have error details + assert row.rollout_status.get_extra_info() == error_info + + # Should not have termination reason + assert row.rollout_status.get_termination_reason() is None + + +class TestStatusEdgeCases: + """Test edge cases and error conditions.""" + + def test_empty_details(self): + """Test Status with empty details.""" + status = Status(code=Status.Code.OK, message="Test", details=[]) + assert status.details == [] + assert status.get_termination_reason() is None + assert status.get_extra_info() is None + + def test_large_metadata(self): + """Test Status with large metadata.""" + large_metadata = {f"key_{i}": f"value_{i}" for i in range(100)} + status = Status.rollout_error("Test error", large_metadata) + + # Should handle large metadata + assert status.get_extra_info() == large_metadata + assert len(status.details) == 1 + assert status.details[0]["metadata"] == large_metadata + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/vite-app/src/types/eval-protocol.ts b/vite-app/src/types/eval-protocol.ts index e6932b40..05ba8b6c 100644 --- a/vite-app/src/types/eval-protocol.ts +++ b/vite-app/src/types/eval-protocol.ts @@ -81,13 +81,11 @@ export const EvalMetadataSchema = z.object({ passed: z.boolean().optional().describe('Whether the evaluation passed based on the threshold') }); -// Rollout status model (matches Python RolloutStatus) -export const RolloutStatusSchema = z.object({ - status: z - .enum(['running', 'finished', 'error', 'stopped']) - .default('finished') - .describe('Status of the rollout.'), - error_message: z.string().optional().describe('Error message if the rollout failed.') +// AIP-193 compatible Status model (matches Python Status) +export const StatusSchema = z.object({ + code: z.number().describe('The status code (numeric value from google.rpc.Code enum)'), + message: z.string().describe('Developer-facing, human-readable debug message in English'), + details: z.array(z.record(z.string(), z.any())).default([]).describe('Additional error information, each packed in a google.protobuf.Any message format') }); export const ExecutionMetadataSchema = z.object({ @@ -101,7 +99,7 @@ export const EvaluationRowSchema = z.object({ messages: z.array(MessageSchema).describe('List of messages in the conversation/trajectory.'), tools: z.array(z.record(z.string(), z.any())).optional().describe('Available tools/functions that were provided to the agent.'), input_metadata: InputMetadataSchema.describe('Metadata related to the input (dataset info, model config, session data, etc.).'), - rollout_status: RolloutStatusSchema.default({ status: 'finished' }).describe('The status of the rollout.'), + rollout_status: StatusSchema.default({ code: 0, message: 'Rollout is running', details: [] }).describe('The status of the rollout following AIP-193 standards.'), execution_metadata: ExecutionMetadataSchema.optional().describe('Metadata about the execution of the evaluation.'), ground_truth: z.string().optional().describe('Optional ground truth reference for this evaluation.'), evaluation_result: EvaluateResultSchema.optional().describe('The evaluation result for this row/trajectory.'), @@ -171,7 +169,7 @@ export type InputMetadata = z.infer; export type CompletionUsage = z.infer; export type EvalMetadata = z.infer; export type EvaluationRow = z.infer; -export type RolloutStatus = z.infer; +export type Status = z.infer; export type ResourceServerConfig = z.infer; export type EvaluationCriteriaModel = z.infer; export type TaskDefinitionModel = z.infer;