Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 7 additions & 8 deletions eval_protocol/mcp/execution/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from vendor.tau2.data_model.message import AssistantMessage, UserMessage
from vendor.tau2.user.user_simulator import UserSimulator

from ...models import EvaluationRow, InputMetadata, Message, RolloutStatus
from ...models import EvaluationRow, InputMetadata, Message, Status
from ...types import TerminationReason, Trajectory, NonSkippableException

if TYPE_CHECKING:
Expand Down Expand Up @@ -136,15 +136,14 @@ async def _execute_with_semaphore(idx):
}

if trajectory.terminated:
evaluation_row.rollout_status.termination_reason = trajectory.termination_reason
evaluation_row.rollout_status.status = RolloutStatus.Status.FINISHED
# preserve the true error mesage if there are any
extra_info = None
if trajectory.control_plane_summary.get("error_message"):
evaluation_row.rollout_status.extra_info = {
"error_message": trajectory.control_plane_summary.get("error_message")
}
extra_info = {"error_message": trajectory.control_plane_summary.get("error_message")}
evaluation_row.rollout_status = Status.rollout_finished(
termination_reason=trajectory.termination_reason, extra_info=extra_info
)
else:
evaluation_row.rollout_status.status = RolloutStatus.Status.RUNNING
evaluation_row.rollout_status = Status.rollout_running()

return evaluation_row

Expand Down
211 changes: 186 additions & 25 deletions eval_protocol/models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, TypedDict, Union
from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict, Union

from openai.types import CompletionUsage
from openai.types.chat.chat_completion_message import (
Expand All @@ -15,6 +15,188 @@
from eval_protocol.types import TerminationReason


class ErrorInfo(BaseModel):
"""
AIP-193 ErrorInfo model for structured error details.

This model follows Google's AIP-193 standard for ErrorInfo:
https://google.aip.dev/193#errorinfo

Attributes:
reason (str): A short snake_case description of the cause of the error.
domain (str): The logical grouping to which the reason belongs.
metadata (Dict[str, Any]): Additional dynamic information as context.
"""

# Constants for reason values
REASON_TERMINATION_REASON: ClassVar[str] = "TERMINATION_REASON"
REASON_EXTRA_INFO: ClassVar[str] = "EXTRA_INFO"

# Domain constant
DOMAIN: ClassVar[str] = "evalprotocol.io"

reason: str = Field(..., description="Short snake_case description of the error cause")
domain: str = Field(..., description="Logical grouping for the error reason")
metadata: Dict[str, Any] = Field(default_factory=dict, description="Additional dynamic information as context")

def to_aip193_format(self) -> Dict[str, Any]:
"""Convert to AIP-193 format with @type field."""
return {
"@type": "type.googleapis.com/google.rpc.ErrorInfo",
"reason": self.reason,
"domain": self.domain,
"metadata": self.metadata,
}

@classmethod
def termination_reason(cls, reason: TerminationReason) -> "ErrorInfo":
"""Create an ErrorInfo for termination reason."""
# Convert TerminationReason enum to string if needed
reason_str = reason.value if isinstance(reason, TerminationReason) else reason
return cls(
reason=cls.REASON_TERMINATION_REASON, domain=cls.DOMAIN, metadata={"termination_reason": reason_str}
)

@classmethod
def extra_info(cls, metadata: Dict[str, Any]) -> "ErrorInfo":
"""Create an ErrorInfo for extra information."""
return cls(reason=cls.REASON_EXTRA_INFO, domain=cls.DOMAIN, metadata=metadata)


class Status(BaseModel):
"""
AIP-193 compatible Status model for standardized error responses.

This model follows Google's AIP-193 standard for error handling:
https://google.aip.dev/193

Attributes:
code (int): The status code, must be the numeric value of one of the elements
of google.rpc.Code enum (e.g., 5 for NOT_FOUND).
message (str): Developer-facing, human-readable debug message in English.
details (List[Dict[str, Any]]): Additional error information, each packed in
a google.protobuf.Any message format.
"""

code: "Status.Code" = Field(..., description="The status code from google.rpc.Code enum")
message: str = Field(..., description="Developer-facing, human-readable debug message in English")
details: List[Dict[str, Any]] = Field(
default_factory=list,
description="Additional error information, each packed in a google.protobuf.Any message format",
)

# Convenience constants for common status codes
class Code(int, Enum):
"""Common gRPC status codes as defined in google.rpc.Code"""

OK = 0
CANCELLED = 1
UNKNOWN = 2
INVALID_ARGUMENT = 3
DEADLINE_EXCEEDED = 4
NOT_FOUND = 5
ALREADY_EXISTS = 6
PERMISSION_DENIED = 7
RESOURCE_EXHAUSTED = 8
FAILED_PRECONDITION = 9
ABORTED = 10
OUT_OF_RANGE = 11
UNIMPLEMENTED = 12
INTERNAL = 13
UNAVAILABLE = 14
DATA_LOSS = 15
UNAUTHENTICATED = 16

# Custom codes for rollout states (using higher numbers to avoid conflicts)
FINISHED = 100

@classmethod
def rollout_running(cls) -> "Status":
"""Create a status indicating the rollout is running."""
return cls(code=cls.Code.OK, message="Rollout is running", details=[])

@classmethod
def rollout_finished(
cls,
termination_reason: Optional[TerminationReason] = None,
extra_info: Optional[Dict[str, Any]] = None,
) -> "Status":
"""Create a status indicating the rollout finished."""
details = []
if termination_reason:
details.append(ErrorInfo.termination_reason(termination_reason).to_aip193_format())
if extra_info:
details.append(ErrorInfo.extra_info(extra_info).to_aip193_format())
return cls(code=cls.Code.FINISHED, message="Rollout finished", details=details)

@classmethod
def rollout_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status":
"""Create a status indicating the rollout failed with an error."""
details = []
if extra_info:
details.append(ErrorInfo.extra_info(extra_info).to_aip193_format())
return cls.error(error_message, details)

@classmethod
def error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status":
"""Create a status indicating the rollout failed with an error."""
return cls(code=cls.Code.INTERNAL, message=error_message, details=details)

def is_running(self) -> bool:
"""Check if the status indicates the rollout is running."""
return self.code == self.Code.OK and self.message == "Rollout is running"

def is_finished(self) -> bool:
"""Check if the status indicates the rollout finished successfully."""
return self.code == self.Code.FINISHED

def is_error(self) -> bool:
"""Check if the status indicates the rollout failed with an error."""
return self.code == self.Code.INTERNAL

def is_stopped(self) -> bool:
"""Check if the status indicates the rollout was stopped."""
return self.code == self.Code.CANCELLED

def get_termination_reason(self) -> Optional[TerminationReason]:
"""Extract termination reason from details if present."""
for detail in self.details:
metadata = detail.get("metadata", {})
if detail.get("reason") == ErrorInfo.REASON_TERMINATION_REASON and "termination_reason" in metadata:
try:
return TerminationReason.from_str(metadata["termination_reason"])
except ValueError:
# If the reason is not a valid enum value, return None
return None
return None

def get_extra_info(self) -> Optional[Dict[str, Any]]:
"""Extract extra info from details if present."""
for detail in self.details:
metadata = detail.get("metadata", {})
reason = detail.get("reason")
# Skip termination_reason and stopped details, return other error info
if reason in [ErrorInfo.REASON_EXTRA_INFO]:
return metadata
return None

def __hash__(self) -> int:
"""Generate a hash for the Status object."""
# Use a stable hash based on code, message, and details
import hashlib

# Create a stable string representation
hash_data = f"{self.code}:{self.message}:{len(self.details)}"

# Add details content for more uniqueness
for detail in sorted(self.details, key=lambda x: str(x)):
hash_data += f":{str(detail)}"

# Generate hash
hash_obj = hashlib.sha256(hash_data.encode("utf-8"))
return int.from_bytes(hash_obj.digest()[:8], byteorder="big")


class ChatCompletionContentPartTextParam(BaseModel):
text: str = Field(..., description="The text content.")
type: Literal["text"] = Field("text", description="The type of the content part.")
Expand Down Expand Up @@ -289,27 +471,6 @@ class ExecutionMetadata(BaseModel):
)


class RolloutStatus(BaseModel):
"""Status of the rollout."""

"""
running: Unfinished rollout which is still in progress.
finished: Rollout finished.
error: Rollout failed due to unexpected error. The rollout record should be discard.
"""

class Status(str, Enum):
RUNNING = "running"
FINISHED = "finished"
ERROR = "error"

status: Status = Field(Status.RUNNING, description="Status of the rollout.")
termination_reason: Optional[TerminationReason] = Field(
None, description="reason of the rollout status, mapped to values in TerminationReason"
)
extra_info: Optional[Dict[str, Any]] = Field(None, description="Extra information about the rollout status.")


class EvaluationRow(BaseModel):
"""
Unified data structure for a single evaluation unit that contains messages,
Expand All @@ -334,9 +495,9 @@ class EvaluationRow(BaseModel):
description="Metadata related to the input (dataset info, model config, session data, etc.).",
)

rollout_status: RolloutStatus = Field(
default_factory=RolloutStatus,
description="The status of the rollout.",
rollout_status: Status = Field(
default_factory=Status.rollout_running,
description="The status of the rollout following AIP-193 standards.",
)

# Ground truth reference (moved from EvaluateResult to top level)
Expand Down
2 changes: 1 addition & 1 deletion eval_protocol/pytest/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def pytest_addoption(parser) -> None:
action="store",
type=int,
default=0,
help=("Failed rollouts (with rollout_status.status == 'error') will be retried up to this many times."),
help=("Failed rollouts (with rollout_status.code indicating error) will be retried up to this many times."),
)
group.addoption(
"--ep-fail-on-max-retry",
Expand Down
12 changes: 5 additions & 7 deletions eval_protocol/pytest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Callable, Dict, List, Literal, Optional, Union

from eval_protocol.dataset_logger.dataset_logger import DatasetLogger
from eval_protocol.models import EvalMetadata, EvaluationRow, RolloutStatus
from eval_protocol.models import EvalMetadata, EvaluationRow, Status
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import (
CompletionParams,
Expand Down Expand Up @@ -282,7 +282,7 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev
try:
# Try original task first
result = await task
result.rollout_status.status = RolloutStatus.Status.FINISHED
result.rollout_status = Status.rollout_finished()
return result
except Exception as e:
# NOTE: we perform these checks because we don't put the backoff decorator on initial batch call. we don't want to retry whole batch if anything fails.
Expand All @@ -295,17 +295,15 @@ async def execute_row_with_backoff(task: asyncio.Task, row: EvaluationRow) -> Ev
# Use shared backoff function for retryable exceptions
try:
result = await execute_row_with_backoff_retry(row)
result.rollout_status.status = RolloutStatus.Status.FINISHED
result.rollout_status = Status.rollout_finished()
return result
except Exception as retry_error:
# Backoff gave up
row.rollout_status.status = RolloutStatus.Status.ERROR
# row.rollout_status.termination_reason = str(retry_error)
row.rollout_status = Status.rollout_error(str(retry_error))
return row
else:
# Non-retryable exception - fail immediately
row.rollout_status.status = RolloutStatus.Status.ERROR
# row.rollout_status.termination_reason = str(e)
row.rollout_status = Status.rollout_error(str(e))
return row

# Process all tasks concurrently with backoff retry
Expand Down
18 changes: 9 additions & 9 deletions tests/test_retry_mechanism.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import pytest

from eval_protocol.models import EvaluateResult, EvaluationRow, Message, RolloutStatus
from eval_protocol.models import EvaluateResult, EvaluationRow, Message, Status
from eval_protocol.pytest.evaluation_test import evaluation_test
from eval_protocol.pytest.rollout_processor import RolloutProcessor
from eval_protocol.pytest.types import RolloutProcessorConfig
Expand Down Expand Up @@ -95,11 +95,11 @@ async def process_single_row(
def test_retry_mechanism(row: EvaluationRow) -> EvaluationRow:
"""MOCK TEST: Tests that retry mechanism works - one task fails on first attempt, succeeds on retry."""
print(
f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})"
f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})"
)

# Assign a score based on success/failure
score = 1.0 if row.rollout_status.status == "finished" else 0.0
score = 1.0 if row.rollout_status.is_finished() else 0.0
row.evaluation_result = EvaluateResult(score=score)

return row
Expand Down Expand Up @@ -191,9 +191,9 @@ async def process_single_row(row: EvaluationRow) -> EvaluationRow:
def test_fail_fast_exceptions(row: EvaluationRow) -> EvaluationRow:
"""Test that fail-fast exceptions like ValueError are not retried."""
print(
f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})"
f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})"
)
score = 1.0 if row.rollout_status.status == "finished" else 0.0
score = 1.0 if row.rollout_status.is_finished() else 0.0
row.evaluation_result = EvaluateResult(score=score)
return row

Expand Down Expand Up @@ -283,8 +283,8 @@ def custom_http_giveup(e):
def test_custom_giveup_function(row: EvaluationRow) -> EvaluationRow:
"""Test custom giveup function behavior."""
task_content = row.messages[0].content if row.messages else ""
print(f"📊 EVALUATED: {task_content} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})")
score = 1.0 if row.rollout_status.status == "finished" else 0.0
print(f"📊 EVALUATED: {task_content} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})")
score = 1.0 if row.rollout_status.is_finished() else 0.0
row.evaluation_result = EvaluateResult(score=score)
return row

Expand Down Expand Up @@ -368,9 +368,9 @@ def simple_4xx_giveup(e):
def test_simple_giveup_function(row: EvaluationRow) -> EvaluationRow:
"""Test that giveup function prevents retries immediately."""
print(
f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.status == 'finished' else 'FAILURE'})"
f"📊 EVALUATED: {row.execution_metadata.rollout_id} ({'SUCCESS' if row.rollout_status.is_finished() else 'FAILURE'})"
)
score = 1.0 if row.rollout_status.status == "finished" else 0.0
score = 1.0 if row.rollout_status.is_finished() else 0.0
row.evaluation_result = EvaluateResult(score=score)
return row

Expand Down
Loading
Loading