diff --git a/eval_protocol/exceptions.py b/eval_protocol/exceptions.py new file mode 100644 index 00000000..3b92c865 --- /dev/null +++ b/eval_protocol/exceptions.py @@ -0,0 +1,177 @@ +""" +Custom exceptions for Eval Protocol that map to gRPC Status codes. + +These exceptions provide a clean way to handle errors and map them to appropriate +Status objects following the AIP-193 standard. +""" + +from typing import Optional + + +class EvalProtocolError(Exception): + """ + Base exception for all Eval Protocol specific errors. + + Maps to Status.Code and can be converted to Status objects for structured logging. + """ + + pass + + +# Standard gRPC status code exceptions +class CancelledError(EvalProtocolError): + """Operation was cancelled (Status.Code.CANCELLED = 1)""" + + status_code = 1 + + +class UnknownError(EvalProtocolError): + """Unknown error occurred (Status.Code.UNKNOWN = 2)""" + + status_code = 2 + + +class InvalidArgumentError(EvalProtocolError): + """Invalid argument provided (Status.Code.INVALID_ARGUMENT = 3)""" + + status_code = 3 + + +class DeadlineExceededError(EvalProtocolError): + """Deadline exceeded (Status.Code.DEADLINE_EXCEEDED = 4)""" + + status_code = 4 + + +class NotFoundError(EvalProtocolError): + """Resource not found (Status.Code.NOT_FOUND = 5)""" + + status_code = 5 + + +class AlreadyExistsError(EvalProtocolError): + """Resource already exists (Status.Code.ALREADY_EXISTS = 6)""" + + status_code = 6 + + +class PermissionDeniedError(EvalProtocolError): + """Permission denied (Status.Code.PERMISSION_DENIED = 7)""" + + status_code = 7 + + +class ResourceExhaustedError(EvalProtocolError): + """Resource exhausted (Status.Code.RESOURCE_EXHAUSTED = 8)""" + + status_code = 8 + + +class FailedPreconditionError(EvalProtocolError): + """Failed precondition (Status.Code.FAILED_PRECONDITION = 9)""" + + status_code = 9 + + +class AbortedError(EvalProtocolError): + """Operation was aborted (Status.Code.ABORTED = 10)""" + + status_code = 10 + + +class 
OutOfRangeError(EvalProtocolError): + """Value out of range (Status.Code.OUT_OF_RANGE = 11)""" + + status_code = 11 + + +class UnimplementedError(EvalProtocolError): + """Operation is not implemented (Status.Code.UNIMPLEMENTED = 12)""" + + status_code = 12 + + +class InternalError(EvalProtocolError): + """Internal server error (Status.Code.INTERNAL = 13)""" + + status_code = 13 + + +class UnavailableError(EvalProtocolError): + """Service unavailable (Status.Code.UNAVAILABLE = 14)""" + + status_code = 14 + + +class DataLossError(EvalProtocolError): + """Unrecoverable data loss (Status.Code.DATA_LOSS = 15)""" + + status_code = 15 + + +class UnauthenticatedError(EvalProtocolError): + """Request lacks valid authentication (Status.Code.UNAUTHENTICATED = 16)""" + + status_code = 16 + + +# Custom EP exceptions +class RolloutFinishedError(EvalProtocolError): + """Rollout completed successfully (Status.Code.FINISHED = 100)""" + + status_code = 100 + + +class RolloutRunningError(EvalProtocolError): + """Rollout is still running (Status.Code.RUNNING = 101)""" + + status_code = 101 + + +class ScoreInvalidError(EvalProtocolError): + """Score is invalid (Status.Code.SCORE_INVALID = 102)""" + + status_code = 102 + + +# Convenience mapping from status codes to exception classes +# Only actual error conditions should raise exceptions +STATUS_CODE_TO_EXCEPTION = { + 0: None, # OK - success, no exception + 1: CancelledError, + 2: UnknownError, + 3: InvalidArgumentError, + 4: DeadlineExceededError, + 5: NotFoundError, + 6: AlreadyExistsError, + 7: PermissionDeniedError, + 8: ResourceExhaustedError, + 9: FailedPreconditionError, + 10: AbortedError, + 11: OutOfRangeError, + 12: UnimplementedError, + 13: InternalError, + 14: UnavailableError, + 15: DataLossError, + 16: UnauthenticatedError, + 100: None, # FINISHED - success, no exception + 101: None, # RUNNING - in progress, no exception + 102: None, # SCORE_INVALID - success, no exception +} + + +def exception_for_status_code(code: int, 
message: str = "") -> Optional[EvalProtocolError]: + """ + Create an exception instance for a given status code. + + Args: + code: Status code from Status.Code enum + message: Optional error message to include in the exception + + Returns: + Exception instance or None if code is OK (0) + """ + exception_class = STATUS_CODE_TO_EXCEPTION.get(code) + if exception_class is None: + return None + return exception_class(message) if message else exception_class() diff --git a/eval_protocol/models.py b/eval_protocol/models.py index d3fe0f61..7180ed72 100644 --- a/eval_protocol/models.py +++ b/eval_protocol/models.py @@ -1,4 +1,6 @@ import os +import logging +import importlib from datetime import datetime from enum import Enum from typing import Any, ClassVar, Dict, List, Literal, Optional, TypedDict, Union @@ -131,6 +133,13 @@ def eval_finished(cls) -> "Status": """Create a status indicating the evaluation finished.""" return cls(code=cls.Code.FINISHED, message="Evaluation finished", details=[]) + @staticmethod + def _build_details_with_extra_info(extra_info: Optional[Dict[str, Any]]) -> List[Dict[str, Any]]: + """Helper to build details list from extra_info.""" + if extra_info: + return [ErrorInfo.extra_info(extra_info).to_aip193_format()] + return [] + @classmethod def aborted(cls, message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": """Create a status indicating the evaluation was aborted.""" @@ -155,19 +164,202 @@ def finished(cls, message: str, details: Optional[List[Dict[str, Any]]] = None) """Create a status indicating the rollout finished.""" return cls(code=cls.Code.FINISHED, message=message, details=details or []) + # Error methods organized by Status.Code enum values (1-16) + + # CANCELLED = 1 + @classmethod + def rollout_cancelled_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": + """Create a status indicating the rollout was cancelled.""" + return cls.cancelled_error(error_message, 
cls._build_details_with_extra_info(extra_info)) + + @classmethod + def cancelled_error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating the operation was cancelled.""" + return cls(code=cls.Code.CANCELLED, message=error_message, details=details or []) + + # UNKNOWN = 2 + @classmethod + def rollout_unknown_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": + """Create a status indicating the rollout failed with an unknown error.""" + return cls.unknown_error(error_message, cls._build_details_with_extra_info(extra_info)) + + @classmethod + def unknown_error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating an unknown error occurred.""" + return cls(code=cls.Code.UNKNOWN, message=error_message, details=details or []) + + # INVALID_ARGUMENT = 3 + @classmethod + def rollout_invalid_argument_error( + cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None + ) -> "Status": + """Create a status indicating the rollout failed with an invalid argument error.""" + return cls.invalid_argument_error(error_message, cls._build_details_with_extra_info(extra_info)) + + @classmethod + def invalid_argument_error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating an invalid argument error occurred.""" + return cls(code=cls.Code.INVALID_ARGUMENT, message=error_message, details=details or []) + + # DEADLINE_EXCEEDED = 4 + @classmethod + def rollout_deadline_exceeded_error( + cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None + ) -> "Status": + """Create a status indicating the rollout failed with a deadline exceeded error.""" + return cls.deadline_exceeded_error(error_message, cls._build_details_with_extra_info(extra_info)) + + @classmethod + def deadline_exceeded_error(cls, error_message: str, details: 
Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating a deadline exceeded error occurred.""" + return cls(code=cls.Code.DEADLINE_EXCEEDED, message=error_message, details=details or []) + + # NOT_FOUND = 5 + @classmethod + def rollout_not_found_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": + """Create a status indicating the rollout failed with a not found error.""" + return cls.not_found_error(error_message, cls._build_details_with_extra_info(extra_info)) + + @classmethod + def not_found_error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating a not found error occurred.""" + return cls(code=cls.Code.NOT_FOUND, message=error_message, details=details or []) + + # ALREADY_EXISTS = 6 + @classmethod + def rollout_already_exists_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": + """Create a status indicating the rollout failed with an already exists error.""" + return cls.already_exists_error(error_message, cls._build_details_with_extra_info(extra_info)) + + @classmethod + def already_exists_error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating an already exists error occurred.""" + return cls(code=cls.Code.ALREADY_EXISTS, message=error_message, details=details or []) + + # PERMISSION_DENIED = 7 + @classmethod + def rollout_permission_denied_error( + cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None + ) -> "Status": + """Create a status indicating the rollout failed with a permission denied error.""" + return cls.permission_denied_error(error_message, cls._build_details_with_extra_info(extra_info)) + + @classmethod + def permission_denied_error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating a permission denied error occurred.""" + 
return cls(code=cls.Code.PERMISSION_DENIED, message=error_message, details=details or []) + + # RESOURCE_EXHAUSTED = 8 + @classmethod + def rollout_resource_exhausted_error( + cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None + ) -> "Status": + """Create a status indicating the rollout failed with a resource exhausted error.""" + return cls.resource_exhausted_error(error_message, cls._build_details_with_extra_info(extra_info)) + + @classmethod + def resource_exhausted_error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating a resource exhausted error occurred.""" + return cls(code=cls.Code.RESOURCE_EXHAUSTED, message=error_message, details=details or []) + + # FAILED_PRECONDITION = 9 + @classmethod + def rollout_failed_precondition_error( + cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None + ) -> "Status": + """Create a status indicating the rollout failed with a failed precondition error.""" + return cls.failed_precondition_error(error_message, cls._build_details_with_extra_info(extra_info)) + + @classmethod + def failed_precondition_error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating a failed precondition error occurred.""" + return cls(code=cls.Code.FAILED_PRECONDITION, message=error_message, details=details or []) + + # ABORTED = 10 + @classmethod + def rollout_aborted_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": + """Create a status indicating the rollout was aborted.""" + return cls.aborted(error_message, cls._build_details_with_extra_info(extra_info)) + + # OUT_OF_RANGE = 11 + @classmethod + def rollout_out_of_range_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": + """Create a status indicating the rollout failed with an out of range error.""" + return cls.out_of_range_error(error_message, 
cls._build_details_with_extra_info(extra_info)) + + @classmethod + def out_of_range_error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating an out of range error occurred.""" + return cls(code=cls.Code.OUT_OF_RANGE, message=error_message, details=details or []) + + # UNIMPLEMENTED = 12 + @classmethod + def rollout_unimplemented_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": + """Create a status indicating the rollout failed with an unimplemented error.""" + return cls.unimplemented_error(error_message, cls._build_details_with_extra_info(extra_info)) + + @classmethod + def unimplemented_error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating an unimplemented error occurred.""" + return cls(code=cls.Code.UNIMPLEMENTED, message=error_message, details=details or []) + + # INTERNAL = 13 + @classmethod + def rollout_internal_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": + """Create a status indicating the rollout failed with an internal error.""" + return cls.internal_error(error_message, cls._build_details_with_extra_info(extra_info)) + + @classmethod + def internal_error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating an internal error occurred.""" + return cls(code=cls.Code.INTERNAL, message=error_message, details=details or []) + + # For backwards compatibility @classmethod def rollout_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": """Create a status indicating the rollout failed with an error.""" - details = [] - if extra_info: - details.append(ErrorInfo.extra_info(extra_info).to_aip193_format()) - return cls.error(error_message, details) + return cls.internal_error(error_message, cls._build_details_with_extra_info(extra_info)) 
@classmethod def error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": - """Create a status indicating the rollout failed with an error.""" + """Create a status indicating an error occurred.""" return cls(code=cls.Code.INTERNAL, message=error_message, details=details or []) + # UNAVAILABLE = 14 + @classmethod + def rollout_unavailable_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": + """Create a status indicating the rollout failed with an unavailable error.""" + return cls.unavailable_error(error_message, cls._build_details_with_extra_info(extra_info)) + + @classmethod + def unavailable_error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating an unavailable error occurred.""" + return cls(code=cls.Code.UNAVAILABLE, message=error_message, details=details or []) + + # DATA_LOSS = 15 + @classmethod + def rollout_data_loss_error(cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None) -> "Status": + """Create a status indicating the rollout failed with a data loss error.""" + return cls.data_loss_error(error_message, cls._build_details_with_extra_info(extra_info)) + + @classmethod + def data_loss_error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + """Create a status indicating a data loss error occurred.""" + return cls(code=cls.Code.DATA_LOSS, message=error_message, details=details or []) + + # UNAUTHENTICATED = 16 + @classmethod + def rollout_unauthenticated_error( + cls, error_message: str, extra_info: Optional[Dict[str, Any]] = None + ) -> "Status": + """Create a status indicating the rollout failed with an unauthenticated error.""" + return cls.unauthenticated_error(error_message, cls._build_details_with_extra_info(extra_info)) + + @classmethod + def unauthenticated_error(cls, error_message: str, details: Optional[List[Dict[str, Any]]] = None) -> "Status": + 
"""Create a status indicating an unauthenticated error occurred.""" + return cls(code=cls.Code.UNAUTHENTICATED, message=error_message, details=details or []) + @classmethod def score_invalid( cls, message: str = "Score is invalid", details: Optional[List[Dict[str, Any]]] = None diff --git a/eval_protocol/pytest/evaluation_test_utils.py b/eval_protocol/pytest/evaluation_test_utils.py index c582d4be..26b0d799 100644 --- a/eval_protocol/pytest/evaluation_test_utils.py +++ b/eval_protocol/pytest/evaluation_test_utils.py @@ -398,7 +398,7 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu else: # Non-retryable exception - fail immediately logging.error(f"❌ Rollout failed (non-retryable error encountered): {repr(e)}") - row.rollout_status = Status.rollout_error(repr(e)) + row.rollout_status = Status.rollout_error(str(e)) return row async def execute_row_with_backoff_and_log( diff --git a/eval_protocol/pytest/exception_config.py b/eval_protocol/pytest/exception_config.py index c8ccaf8e..e4bb1b7c 100644 --- a/eval_protocol/pytest/exception_config.py +++ b/eval_protocol/pytest/exception_config.py @@ -12,6 +12,8 @@ import requests import httpx +import eval_protocol.exceptions + # Default exceptions that should be retried with backoff DEFAULT_RETRYABLE_EXCEPTIONS: Set[Type[Exception]] = { @@ -29,13 +31,22 @@ httpx.TimeoutException, httpx.NetworkError, httpx.RemoteProtocolError, + # LiteLLM library exceptions litellm.exceptions.RateLimitError, litellm.exceptions.InternalServerError, litellm.exceptions.Timeout, litellm.exceptions.NotFoundError, - litellm.exceptions.BadRequestError, litellm.exceptions.ServiceUnavailableError, litellm.exceptions.APIError, + litellm.exceptions.BadRequestError, + # Eval Protocol exceptions + eval_protocol.exceptions.UnknownError, + eval_protocol.exceptions.DeadlineExceededError, + eval_protocol.exceptions.NotFoundError, + eval_protocol.exceptions.PermissionDeniedError, + eval_protocol.exceptions.UnavailableError, + 
eval_protocol.exceptions.UnauthenticatedError, + eval_protocol.exceptions.ResourceExhaustedError, } diff --git a/eval_protocol/pytest/remote_rollout_processor.py b/eval_protocol/pytest/remote_rollout_processor.py index 68d47dcd..dd179e34 100644 --- a/eval_protocol/pytest/remote_rollout_processor.py +++ b/eval_protocol/pytest/remote_rollout_processor.py @@ -10,6 +10,7 @@ DataLoaderConfig, ) from eval_protocol.adapters.fireworks_tracing import FireworksTracingAdapter +from eval_protocol.exceptions import exception_for_status_code from .rollout_processor import RolloutProcessor from .types import RolloutProcessorConfig @@ -93,17 +94,11 @@ async def _process_row(row: EvaluationRow) -> EvaluationRow: def _post_init() -> None: url = f"{remote_base_url}/init" try: - r = requests.post(url, json=init_payload.model_dump(), timeout=30) + r = requests.post(url, json=init_payload.model_dump(), timeout=300) r.raise_for_status() except requests.exceptions.Timeout: raise TimeoutError( - "The /init endpoint timed out after 30 seconds. " - "CRITICAL: The /init endpoint must return immediately (within 30s) and NOT block on rollout execution. " - "Your remote server should:\n" - "1. Accept the /init request and return a 200 response immediately\n" - "2. Process the actual rollout asynchronously in the background\n" - "3. Use the /status endpoint to report progress\n" - "For Python/Node.js: Start a separate process per rollout to avoid blocking the /init response." + f"The /init endpoint tried {url} with {init_payload.model_dump()} but timed out after 300 seconds." 
) await asyncio.to_thread(_post_init) @@ -138,9 +133,9 @@ def _get_status() -> Dict[str, Any]: # For all other exceptions, raise them raise - # Search Fireworks tracing logs for completion - completed_logs = self._tracing_adapter.search_logs( - tags=[f"rollout_id:{row.execution_metadata.rollout_id}"] + # Search Fireworks tracing logs for completion (run in thread to avoid blocking event loop) + completed_logs = await asyncio.to_thread( + self._tracing_adapter.search_logs, tags=[f"rollout_id:{row.execution_metadata.rollout_id}"] ) # Filter for logs that actually have status information status_logs = [] @@ -166,6 +161,11 @@ def _get_status() -> Dict[str, Any]: f"Found Fireworks log for rollout {row.execution_metadata.rollout_id} with status code {status_code}" ) + # Create and raise exception if appropriate, preserving original message + exception = exception_for_status_code(status_code, status_message) + if exception is not None: + raise exception + row.rollout_status = Status( code=Status.Code(status_code), message=status_message, @@ -181,7 +181,7 @@ def _get_status() -> Dict[str, Any]: f"Loop completed without breaking for {row.execution_metadata.rollout_id}, which means we timed out" ) # Loop completed without breaking, which means we timed out - row.rollout_status = Status.rollout_error( + row.rollout_status = Status.rollout_deadline_exceeded_error( f"Rollout {row.execution_metadata.rollout_id} timed out after {timeout_seconds} seconds" ) diff --git a/eval_protocol/pytest/tracing_utils.py b/eval_protocol/pytest/tracing_utils.py index 27f8d14c..6ea69371 100644 --- a/eval_protocol/pytest/tracing_utils.py +++ b/eval_protocol/pytest/tracing_utils.py @@ -151,14 +151,14 @@ def update_row_with_remote_trace( output_rows: List[EvaluationRow] = [r for result in results for r in result.rows] if len(output_rows) == 0: # Fallback to original row if no remote data found - row.rollout_status = Status(code=Status.Code.NOT_FOUND, message="No remote data found for rollout") + 
row.rollout_status = Status.rollout_not_found_error("No remote data found for rollout") return None elif len(output_rows) == 1: # Return the remote row remote_row = output_rows[0] # if the remote_row has the same number of messages as the original row, something went wrong if len(remote_row.messages) == len(row.messages): - row.rollout_status = Status.rollout_error( + row.rollout_status = Status.rollout_internal_error( "Rollout finished with the same number of messages as the original row" ) return None diff --git a/eval_protocol/quickstart/svg_agent/vercel_svg_server/api/init.py b/eval_protocol/quickstart/svg_agent/vercel_svg_server/api/init.py index d376880e..ffd8b9ea 100644 --- a/eval_protocol/quickstart/svg_agent/vercel_svg_server/api/init.py +++ b/eval_protocol/quickstart/svg_agent/vercel_svg_server/api/init.py @@ -9,8 +9,10 @@ import os import logging import sys -from http.server import BaseHTTPRequestHandler +import asyncio +from flask import Flask, request, jsonify from openai import OpenAI +import openai from dotenv import load_dotenv from eval_protocol import Status, InitRequest, FireworksTracingHttpHandler, RolloutIdFilter @@ -44,119 +46,157 @@ def filter(self, record: logging.LogRecord) -> bool: # Attach Fireworks tracing handler to root logger (non-stream HTTP sink) root_logger.addHandler(FireworksTracingHttpHandler()) - -class handler(BaseHTTPRequestHandler): - def do_POST(self): - try: - # Read and parse request body - content_length = int(self.headers.get("Content-Length", 0)) - request_body = self.rfile.read(content_length).decode("utf-8") - request_data = json.loads(request_body) - - # Parse as InitRequest - req = InitRequest(**request_data) - - # Attach rollout_id filter to logger - logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}") - logger.addFilter(RolloutIdFilter(req.metadata.rollout_id)) - - # Validate required fields - if not req.messages: - error_msg = "messages is required" - logger.error(error_msg, extra={"status": 
Status.rollout_error(error_msg)}) - self._send_error(400, error_msg) - return - - model = req.completion_params.get("model") - if model and isinstance(model, str) and model.startswith("fireworks_ai/"): - model = model[len("fireworks_ai/") :] - - # Prepare completion arguments - completion_kwargs = { - "messages": req.messages, - "model": model, - "temperature": req.completion_params.get("temperature"), - "max_tokens": req.completion_params.get("max_tokens"), - } - - # Add tools if present - if req.tools: - completion_kwargs["tools"] = req.tools - - # Get API key (prefer request api_key, fallback to environment) - api_key = req.api_key or os.environ.get("FIREWORKS_API_KEY") - if not api_key: - error_msg = "API key not provided in request or FIREWORKS_API_KEY environment variable" - logger.error(error_msg, extra={"status": Status.rollout_error(error_msg)}) - self._send_error(500, error_msg) - return - - # Create OpenAI client - client = OpenAI(base_url=req.model_base_url, api_key=api_key) - - logger.info(f"Sending completion request to model {req.completion_params.get('model')}") - - # Make the model call - completion = client.chat.completions.create(**completion_kwargs) - - logger.info(f"Completed response: {completion}") - - # Log completion status - logger.info(f"Rollout {req.metadata.rollout_id} completed", extra={"status": Status.rollout_finished()}) - - # Return the completion response - response_data = { - "status": "completed", - "rollout_id": req.metadata.rollout_id, - "choices": [ - { - "message": { - "role": completion.choices[0].message.role, - "content": completion.choices[0].message.content, - } - } - ], - } - - self._send_json_response(200, response_data) - - except Exception as e: - # Log error if we have the request context - if "req" in locals() and "logger" in locals(): - logger.error(f"❌ Error in rollout {req.metadata.rollout_id}: {e}") - logger.error(str(e), extra={"status": Status.rollout_error(str(e))}) - - self._send_error(500, str(e)) - - def 
do_GET(self): - """Health check endpoint""" - self._send_json_response( - 200, - { - "status": "ok", - "message": "SVGBench Vercel Serverless Function", - "endpoints": {"POST /": "Process SVGBench evaluation requests"}, - }, +# Create Flask app +app = Flask(__name__) + + +async def execute_rollout_background(req: InitRequest, api_key: str): + """Execute the OpenAI completion in background and log results""" + # Attach rollout_id filter to logger + logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}") + logger.addFilter(RolloutIdFilter(req.metadata.rollout_id)) + + model = req.completion_params.get("model") + # Uncomment if you need to strip fireworks_ai/ prefix + # if model and isinstance(model, str) and model.startswith("fireworks_ai/"): + # model = model[len("fireworks_ai/"):] + + # Prepare completion arguments + completion_kwargs = { + "messages": req.messages, + # "messages": [{"role": "user", "content": "Hello, how are you?"}], + "model": model, + "temperature": req.completion_params.get("temperature"), + "max_tokens": req.completion_params.get("max_tokens"), + } + + # Add tools if present + if req.tools: + completion_kwargs["tools"] = req.tools + + logger.info( + f"DEBUG: {req.model_base_url}, COMPLETION_KWARGS: {completion_kwargs}, API_KEY: {api_key}, MODEL: {model}" + ) + + # Create AsyncOpenAI client + # client = AsyncOpenAI(base_url=req.model_base_url, api_key=api_key) + client = OpenAI(base_url=req.model_base_url, api_key=api_key) + + logger.info(f"Sending completion request to model {model}") + + # Make the async model call with timeout + import time + + logger.info(f"timing start: {time.time()}") + + try: + completion = client.chat.completions.create(**completion_kwargs) + except ( + openai.AuthenticationError, + openai.PermissionDeniedError, + ) as e: + # These errors should be logged and will be retried by RemoteRolloutProcessor + logger.error( + f"Rollout {req.metadata.rollout_id} failed: {e}", + extra={"status": 
Status.rollout_permission_denied_error(str(e))}, ) - - def do_OPTIONS(self): - """Handle CORS preflight requests""" - self.send_response(200) - self.send_header("Access-Control-Allow-Origin", "*") - self.send_header("Access-Control-Allow-Methods", "POST, GET, OPTIONS") - self.send_header("Access-Control-Allow-Headers", "Content-Type") - self.end_headers() - - def _send_json_response(self, status_code: int, data: dict): - """Send a JSON response""" - self.send_response(status_code) - self.send_header("Content-Type", "application/json") - self.send_header("Access-Control-Allow-Origin", "*") - self.send_header("Access-Control-Allow-Methods", "POST, GET, OPTIONS") - self.send_header("Access-Control-Allow-Headers", "Content-Type") - self.end_headers() - self.wfile.write(json.dumps(data).encode("utf-8")) - - def _send_error(self, status_code: int, message: str): - """Send an error response""" - self._send_json_response(status_code, {"error": message}) + return + except openai.NotFoundError as e: + logger.error( + f"Rollout {req.metadata.rollout_id} failed: {e}", extra={"status": Status.rollout_not_found_error(str(e))} + ) + return + except openai.RateLimitError as e: + logger.error( + f"Rollout {req.metadata.rollout_id} failed: {e}", + extra={"status": Status.rollout_resource_exhausted_error(str(e))}, + ) + return + except Exception as e: + # Non-OpenAI errors (shouldn't normally happen but catch anyway) + logger.error( + f"Rollout {req.metadata.rollout_id} failed with unexpected error: {e}", + extra={"status": Status.rollout_internal_error(str(e))}, + ) + return + + logger.info(f"Completed response: {completion}") + logger.info(f"timing end: {time.time()}") + # Log successful completion - THIS IS WHAT RemoteRolloutProcessor POLLS FOR + logger.info(f"Rollout {req.metadata.rollout_id} completed", extra={"status": Status.rollout_finished()}) + + +@app.route("/init", methods=["POST"]) +async def init(): + try: + # Parse as InitRequest + req = 
InitRequest(**request.get_json()) + + # Create logger for immediate validation logging + logger = logging.getLogger(f"{__name__}.{req.metadata.rollout_id}") + logger.addFilter(RolloutIdFilter(req.metadata.rollout_id)) + + # Validate required fields + if not req.messages: + error_msg = "messages is required" + logger.error(error_msg, extra={"status": Status.rollout_internal_error(error_msg)}) + return jsonify({"error": error_msg}), 400 + + # Get API key (prefer request api_key, fallback to environment) + if req.api_key: + logger.info("Using API key from request") + api_key = req.api_key + elif os.environ.get("FIREWORKS_API_KEY"): + logger.info("Using API key from environment") + api_key = os.environ.get("FIREWORKS_API_KEY") + else: + error_msg = "API key not provided in request or environment variable" + logger.error(error_msg, extra={"status": Status.rollout_internal_error(error_msg)}) + return jsonify({"error": error_msg}), 401 + + # 🔥 FIRE: Return immediately with acceptance (within 30s requirement) + response_data = { + "status": "accepted", + "rollout_id": req.metadata.rollout_id, + "message": "Rollout processing started", + } + + # Fire and forget: Execute rollout asynchronously + asyncio.create_task(execute_rollout_background(req, api_key or "")) + + return jsonify(response_data), 200 + + except Exception as e: + # For request parsing errors, return error immediately (don't retry) + return jsonify({"error": f"Request parsing error: {str(e)}"}), 400 + + +@app.route("/", methods=["GET"]) +def health_check(): + """Health check endpoint""" + return jsonify( + { + "status": "ok", + "message": "SVGBench Vercel Serverless Function", + "endpoints": {"POST /": "Process SVGBench evaluation requests"}, + } + ) + + +@app.route("/", methods=["OPTIONS"]) +def options_handler(): + """Handle CORS preflight requests""" + response = jsonify({}) + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = "POST, GET, OPTIONS" + 
response.headers["Access-Control-Allow-Headers"] = "Content-Type" + return response + + +# Add CORS headers to all responses +@app.after_request +def add_cors_headers(response): + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = "POST, GET, OPTIONS" + response.headers["Access-Control-Allow-Headers"] = "Content-Type" + return response diff --git a/eval_protocol/quickstart/svg_agent/vercel_svg_server/requirements.txt b/eval_protocol/quickstart/svg_agent/vercel_svg_server/requirements.txt index dadd2db1..f4ce92fc 100644 --- a/eval_protocol/quickstart/svg_agent/vercel_svg_server/requirements.txt +++ b/eval_protocol/quickstart/svg_agent/vercel_svg_server/requirements.txt @@ -1,3 +1,4 @@ openai>=1.0.0 python-dotenv>=0.19.0 -eval_protocol>=0.2.58 +eval_protocol>=0.2.71 +Flask[async]==3.0.3 diff --git a/eval_protocol/quickstart/svg_agent/vercel_svg_server/vercel.json b/eval_protocol/quickstart/svg_agent/vercel_svg_server/vercel.json index 112be6e9..4291b8c1 100644 --- a/eval_protocol/quickstart/svg_agent/vercel_svg_server/vercel.json +++ b/eval_protocol/quickstart/svg_agent/vercel_svg_server/vercel.json @@ -1,3 +1,5 @@ { - "redirects": [{ "source": "/init", "destination": "/api/init" }] + "rewrites": [ + { "source": "/(.*)", "destination": "/api/init" } + ] } diff --git a/tests/remote_server/test_remote_fireworks_propagate_status.py b/tests/remote_server/test_remote_fireworks_propagate_status.py index e415ed61..7c05172f 100644 --- a/tests/remote_server/test_remote_fireworks_propagate_status.py +++ b/tests/remote_server/test_remote_fireworks_propagate_status.py @@ -1,15 +1,4 @@ -# MANUAL SERVER STARTUP REQUIRED: -# -# For Python server testing, start: -# python -m tests.remote_server.remote_server (runs on http://127.0.0.1:3000) -# -# For TypeScript server testing, start: -# cd tests/remote_server/typescript-server -# npm install -# npm start -# -# The TypeScript server should be running on http://127.0.0.1:3000 -# 
def test_success_status_codes_no_exception():
    """Success and in-progress status codes must map to no exception."""
    # code -> symbolic name; insertion order is the order the cases run in.
    non_error_codes = {
        0: "OK",
        100: "FINISHED",
        101: "RUNNING",
        102: "SCORE_INVALID",
    }

    for code, name in non_error_codes.items():
        exception = exception_for_status_code(code)
        assert exception is None, f"Status code {code} ({name}) should not raise exception"
def test_exception_attributes():
    """Test that created exceptions have the expected type and status_code.

    For a sample of error codes, checks both the exception class (its
    ``status_code`` class attribute) and the instance actually returned by
    ``exception_for_status_code``.
    """
    # Test a few exception types
    test_cases = [
        (1, CancelledError, "CANCELLED"),
        (5, NotFoundError, "NOT_FOUND"),
        (13, InternalError, "INTERNAL"),
    ]

    for code, expected_class, name in test_cases:
        exception = exception_for_status_code(code)

        # Previously the created exception was never inspected, so a factory
        # returning the wrong type (or None) would still pass this test.
        assert isinstance(exception, expected_class), (
            f"Status code {code} ({name}) should create {expected_class.__name__}"
        )

        assert hasattr(expected_class, "status_code"), f"{expected_class.__name__} should have status_code attribute"
        assert expected_class.status_code == code, f"{expected_class.__name__}.status_code should be {code}"

        # The class attribute must also be visible on the instance.
        assert exception.status_code == code, (
            f"{expected_class.__name__} instance for {name} should expose status_code {code}"
        )
def test_exception_inheritance():
    """Every concrete error type must derive from EvalProtocolError and Exception."""
    # Exceptions mirroring the standard gRPC status codes.
    grpc_error_types = (
        CancelledError,
        UnknownError,
        InvalidArgumentError,
        DeadlineExceededError,
        NotFoundError,
        AlreadyExistsError,
        PermissionDeniedError,
        ResourceExhaustedError,
        FailedPreconditionError,
        AbortedError,
        OutOfRangeError,
        UnimplementedError,
        InternalError,
        UnavailableError,
        DataLossError,
        UnauthenticatedError,
    )
    # Eval Protocol specific additions.
    rollout_error_types = (
        RolloutFinishedError,
        RolloutRunningError,
        ScoreInvalidError,
    )

    for exception_class in grpc_error_types + rollout_error_types:
        assert issubclass(exception_class, EvalProtocolError), (
            f"{exception_class.__name__} should inherit from EvalProtocolError"
        )
        assert issubclass(exception_class, Exception), f"{exception_class.__name__} should inherit from Exception"
101, "description": "Rollout still in progress", "should_raise": False}, + # Error scenarios that should trigger retry logic + { + "status_code": 4, + "description": "Request timeout", + "should_raise": True, + "expected_exception": DeadlineExceededError, + }, + { + "status_code": 5, + "description": "Model not found", + "should_raise": True, + "expected_exception": NotFoundError, + }, + { + "status_code": 7, + "description": "API key invalid", + "should_raise": True, + "expected_exception": PermissionDeniedError, + }, + { + "status_code": 8, + "description": "Rate limit exceeded", + "should_raise": True, + "expected_exception": ResourceExhaustedError, + }, + { + "status_code": 13, + "description": "Internal server error", + "should_raise": True, + "expected_exception": InternalError, + }, + { + "status_code": 14, + "description": "Service temporarily unavailable", + "should_raise": True, + "expected_exception": UnavailableError, + }, + ] + + for scenario in scenarios: + status_code = scenario["status_code"] + description = scenario["description"] + should_raise = scenario["should_raise"] + + # This is the pattern used in RemoteRolloutProcessor + exception = exception_for_status_code(status_code) + + if should_raise: + expected_exception = scenario["expected_exception"] + assert exception is not None, f"Scenario '{description}' should create exception" + assert isinstance(exception, expected_exception), ( + f"Scenario '{description}' should create {expected_exception.__name__}" + ) + + # Test that the exception can be raised and caught for retry logic + with pytest.raises(expected_exception): + raise exception + + else: + assert exception is None, f"Scenario '{description}' should not create exception" + + +def test_exception_status_code_attributes(): + """Test that all exceptions have correct status_code attributes.""" + expected_mappings = [ + (CancelledError, 1), + (UnknownError, 2), + (InvalidArgumentError, 3), + (DeadlineExceededError, 4), + (NotFoundError, 
def test_exception_message_preservation():
    """A message passed to exception_for_status_code must come back via str().

    Each case is also exercised without a message to confirm the factory
    still returns an exception of the expected type.
    """
    cases = (
        (13, "test error", InternalError),
        (5, "Model xyz not found", NotFoundError),
        (7, "Invalid API key", PermissionDeniedError),
    )

    for status_code, message, expected_exception_class in cases:
        # Explicit message: type and text are both checked.
        with_message = exception_for_status_code(status_code, message)
        assert with_message is not None
        assert isinstance(with_message, expected_exception_class)
        assert str(with_message) == message, f"Exception should preserve message '{message}'"

        # Message omitted: only the type is checked.
        without_message = exception_for_status_code(status_code)
        assert without_message is not None
        assert isinstance(without_message, expected_exception_class)