-
Notifications
You must be signed in to change notification settings - Fork 16
Harden /evaluate error handling and remove mutable kwargs default #438
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -243,3 +243,6 @@ package.json | |
| tau2-bench | ||
| *.err | ||
| eval-protocol | ||
| _pytest_deps/ | ||
| .test_deps/ | ||
| .test_deps/ | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,5 @@ | ||
| import importlib | ||
| import logging | ||
| import os | ||
| from typing import Any, Dict, List, Optional | ||
|
|
||
|
|
@@ -9,12 +10,15 @@ | |
| # Assuming these models are correctly defined in eval_protocol.models | ||
| from eval_protocol.models import EvaluateResult, Message | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
|
|
||
|
|
||
| # --- Request and Response Models --- | ||
| class EvaluationRequest(BaseModel): | ||
| messages: List[Dict[str, Any]] # Could also be List[Message] if we enforce that model on input | ||
| ground_truth: Optional[str] = None | ||
| kwargs: Optional[Dict[str, Any]] = {} | ||
| # Avoid shared mutable default across requests. | ||
| kwargs: Optional[Dict[str, Any]] = None | ||
|
|
||
|
|
||
| # --- Global variable to store the loaded reward function --- | ||
|
|
@@ -74,8 +78,10 @@ async def evaluate_endpoint(request: EvaluationRequest): | |
| if not isinstance(result, EvaluateResult): | ||
| # This case should ideally not happen if functions are correctly decorated | ||
| # and return EvaluateResult, but good to have a fallback. | ||
| print( | ||
| f"Warning: Reward function '{_REWARD_FUNCTION_NAME}' did not return an EvaluateResult instance. Type: {type(result)}" | ||
| logger.warning( | ||
| "Reward function '%s' did not return an EvaluateResult instance. Type: %s", | ||
| _REWARD_FUNCTION_NAME, | ||
| type(result), | ||
| ) | ||
| # Attempt to construct an EvaluateResult if it's a dict-like object, | ||
| # otherwise, this will raise an error or return a poorly formed response. | ||
|
|
@@ -89,15 +95,18 @@ async def evaluate_endpoint(request: EvaluationRequest): | |
|
|
||
| return result | ||
| except ValidationError as ve: # Pydantic validation error from reward function's input/output | ||
| print(f"Validation Error calling reward function '{_REWARD_FUNCTION_NAME}': {ve}") | ||
| logger.warning( | ||
| "Validation error calling reward function '%s': %s", | ||
| _REWARD_FUNCTION_NAME, | ||
| ve, | ||
| ) | ||
| raise HTTPException( | ||
| status_code=422, | ||
| detail=f"Input/Output validation error for reward function: {ve.errors()}", | ||
| ) | ||
| except Exception as e: | ||
| print(f"Error during evaluation with reward function '{_REWARD_FUNCTION_NAME}': {e}") | ||
| # Consider logging the full traceback here | ||
| raise HTTPException(status_code=500, detail=f"Internal server error during evaluation: {str(e)}") | ||
| logger.exception("Error during evaluation with reward function '%s'", _REWARD_FUNCTION_NAME) | ||
| raise HTTPException(status_code=500, detail="Internal server error during evaluation.") | ||
|
|
||
|
|
||
| @app.get("/health") | ||
|
|
@@ -121,9 +130,9 @@ def load_reward_function(import_string: str): | |
| module = importlib.import_module(module_path) | ||
| _LOADED_REWARD_FUNCTION = getattr(module, function_name) | ||
| _REWARD_FUNCTION_NAME = import_string | ||
| print(f"Successfully loaded reward function: {_REWARD_FUNCTION_NAME}") | ||
| logger.info("Successfully loaded reward function: %s", _REWARD_FUNCTION_NAME) | ||
| except Exception as e: | ||
| print(f"Error loading reward function from '{import_string}': {e}") | ||
| logger.exception("Error loading reward function from '%s'", import_string) | ||
| _LOADED_REWARD_FUNCTION = None | ||
| _REWARD_FUNCTION_NAME = "Error loading" | ||
| raise # Re-raise to make it fatal if loading fails on startup | ||
|
|
@@ -153,13 +162,16 @@ def load_reward_function(import_string: str): | |
| try: | ||
| load_reward_function(args.import_string) | ||
| except Exception: | ||
| print("Failed to load reward function. Exiting.") | ||
| logger.error("Failed to load reward function. Exiting.") | ||
| exit(1) | ||
|
|
||
| if not _LOADED_REWARD_FUNCTION: | ||
| print(f"Reward function {_REWARD_FUNCTION_NAME} could not be loaded. Server will not start correctly.") | ||
| logger.error( | ||
| "Reward function %s could not be loaded. Server will not start correctly.", | ||
| _REWARD_FUNCTION_NAME, | ||
| ) | ||
| # Depending on desired behavior, could exit here or let it run and fail on /evaluate | ||
| exit(1) | ||
|
|
||
| print(f"Starting server for reward function: {args.import_string} on http://{args.host}:{args.port}") | ||
| logger.info("Starting server for reward function: %s on http://%s:%s", args.import_string, args.host, args.port) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Missing logging configuration silences startup info messages — Medium Severity. The module configures a logger but never calls a logging setup (e.g. `logging.basicConfig`), so startup `logger.info` messages may be silently dropped. Additional Locations (1) |
||
| uvicorn.run(app, host=args.host, port=args.port) # reload=args.reload for dev | ||


There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Duplicate `.test_deps/` entry in `.gitignore` — Low Severity
`.test_deps/` appears twice on consecutive lines in `.gitignore`. One of the duplicate entries can be removed.