Skip to content

Commit cb3206c

Browse files
committed
adding response quality validation for retry
1 parent f409213 commit cb3206c

File tree

8 files changed

+427
-25
lines changed

8 files changed

+427
-25
lines changed

eval_protocol/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ class ScoreInvalidError(EvalProtocolError):
134134
status_code = 102
135135

136136

137+
class ResponseQualityError(EvalProtocolError):
    """Signal that a response quality check failed.

    Maps to ``Status.Code.RESPONSE_QUALITY_ERROR`` (numeric code 103).
    """

    status_code = 103
141+
142+
137143
# Convenience mapping from status codes to exception classes
138144
# Only actual error conditions should raise exceptions
139145
STATUS_CODE_TO_EXCEPTION = {
@@ -157,6 +163,7 @@ class ScoreInvalidError(EvalProtocolError):
157163
100: None, # FINISHED - success, no exception
158164
101: None, # RUNNING - in progress, no exception
159165
102: None, # SCORE_INVALID - success, no exception
166+
103: ResponseQualityError, # RESPONSE_QUALITY_ERROR - quality check failed
160167
}
161168

162169

eval_protocol/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ class Code(int, Enum):
117117
FINISHED = 100
118118
RUNNING = 101
119119
SCORE_INVALID = 102
120+
RESPONSE_QUALITY_ERROR = 103
120121

121122
@classmethod
122123
def rollout_running(cls) -> "Status":
@@ -367,6 +368,13 @@ def score_invalid(
367368
"""Create a status indicating the score is invalid."""
368369
return cls(code=cls.Code.SCORE_INVALID, message=message, details=details or [])
369370

371+
@classmethod
def response_quality_error(
    cls, message: str = "Response quality check failed", details: Optional[List[Dict[str, Any]]] = None
) -> "Status":
    """Build a status carrying the RESPONSE_QUALITY_ERROR code.

    Args:
        message: Human-readable description of the failure.
        details: Optional list of detail dicts; a missing or empty value
            is replaced by a fresh empty list.

    Returns:
        A new ``cls`` instance with ``code=cls.Code.RESPONSE_QUALITY_ERROR``.
    """
    detail_items = details if details else []
    return cls(code=cls.Code.RESPONSE_QUALITY_ERROR, message=message, details=detail_items)
377+
370378
def is_running(self) -> bool:
371379
"""Check if the status indicates the rollout is running."""
372380
return self.code == self.Code.RUNNING

eval_protocol/pytest/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .evaluation_test import evaluation_test
99
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
1010
from .rollout_processor import RolloutProcessor
11+
from .rollout_result_post_processor import RolloutResultPostProcessor, NoOpRolloutResultPostProcessor
1112
from .types import RolloutProcessorConfig
1213

1314
# Conditional import for optional dependencies
@@ -42,6 +43,8 @@
4243
"ExceptionHandlerConfig",
4344
"BackoffConfig",
4445
"get_default_exception_handler_config",
46+
"RolloutResultPostProcessor",
47+
"NoOpRolloutResultPostProcessor",
4548
]
4649

4750
# Only add to __all__ if available

eval_protocol/pytest/evaluation_test_utils.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ServerMode,
2929
)
3030
from eval_protocol.pytest.exception_config import get_default_exception_handler_config
31+
from eval_protocol.exceptions import ResponseQualityError
3132

3233
import logging
3334
import json
@@ -363,7 +364,21 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
363364
"""Execute rollout for a single row with backoff retry."""
364365
retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False})
365366
retry_tasks = rollout_processor([row], retry_config)
366-
return await retry_tasks[0]
367+
result = await retry_tasks[0]
368+
369+
# Apply post-processing quality checks if configured
370+
# This must be inside the retry function so ResponseQualityError can trigger retries
371+
if config.post_processor is not None:
372+
try:
373+
config.post_processor.process(result)
374+
except ResponseQualityError as quality_error:
375+
# Re-raise ResponseQualityError to trigger retry logic
376+
raise quality_error
377+
except Exception as post_process_error:
378+
# Wrap unexpected post-processor errors in ResponseQualityError
379+
raise ResponseQualityError(f"Post-processor failed: {post_process_error}") from post_process_error
380+
381+
return result
367382

368383
async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: EvaluationRow) -> EvaluationRow:
369384
"""Execute a single row task with backoff retry."""
@@ -372,6 +387,15 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu
372387
# Try original task first
373388
result = await task # pyright: ignore[reportUnknownVariableType]
374389

390+
# Apply post-processing quality checks if configured
391+
if config.post_processor is not None:
392+
try:
393+
config.post_processor.process(result)
394+
except ResponseQualityError as quality_error:
395+
raise quality_error
396+
except Exception as post_process_error:
397+
raise ResponseQualityError(f"Post-processor failed: {post_process_error}") from post_process_error
398+
375399
_set_rollout_status_to_finished(result)
376400

377401
return result # pyright: ignore[reportUnknownVariableType]
@@ -384,9 +408,9 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu
384408

385409
if is_retryable and not should_giveup:
386410
# Use shared backoff function for retryable exceptions
411+
# Note: post-processing is handled inside execute_row_with_backoff_retry
387412
try:
388413
result = await execute_row_with_backoff_retry(row)
389-
390414
_set_rollout_status_to_finished(result)
391415

392416
return result

eval_protocol/pytest/exception_config.py

Lines changed: 112 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import os
66
from dataclasses import dataclass, field
7-
from typing import Callable, Set, Type, Union
7+
from typing import Callable, Dict, Set, Type, Union
88

99
import backoff
1010

@@ -47,6 +47,7 @@
4747
eval_protocol.exceptions.UnavailableError,
4848
eval_protocol.exceptions.UnauthenticatedError,
4949
eval_protocol.exceptions.ResourceExhaustedError,
50+
eval_protocol.exceptions.ResponseQualityError,
5051
}
5152

5253

@@ -78,39 +79,112 @@ class BackoffConfig:
7879
# Optional custom giveup function - if provided, overrides the default exception handling logic
7980
giveup_func: Callable[[Exception], bool] = lambda e: False
8081

81-
def get_backoff_decorator(self, exceptions: Set[Type[Exception]]):
82-
"""Get the appropriate backoff decorator based on configuration."""
82+
def get_backoff_decorator(self, exceptions: Set[Type[Exception]], exception_backoff_overrides: Dict[Type[Exception], "BackoffConfig"] | None = None):
83+
"""Get the appropriate backoff decorator based on configuration.
84+
85+
Args:
86+
exceptions: Set of exception types to retry
87+
exception_backoff_overrides: Optional mapping of exception types to custom backoff configs.
88+
If an exception type has an override, that config will be used instead of this one.
89+
"""
8390
if not exceptions:
8491
# If no exceptions specified, return a no-op decorator
8592
def no_op_decorator(func):
8693
return func
8794

8895
return no_op_decorator
8996

90-
if self.strategy == "expo":
97+
# If no overrides, use simple decorator for all exceptions
98+
if not exception_backoff_overrides:
99+
return self._create_single_decorator(exceptions, self)
100+
101+
# Group exceptions by their backoff config to avoid double backoff
102+
# Each exception type gets exactly one decorator based on its config
103+
# Use a tuple of config attributes as the key since BackoffConfig is not hashable
104+
config_to_exceptions: Dict[tuple, tuple[Set[Type[Exception]], "BackoffConfig"]] = {}
105+
106+
for exc_type in exceptions:
107+
if exc_type in exception_backoff_overrides:
108+
override_config = exception_backoff_overrides[exc_type]
109+
else:
110+
override_config = self
111+
112+
# Create a hashable key from config attributes
113+
# Note: jitter and giveup_func are callable, which are hashable in Python
114+
config_key = (
115+
override_config.strategy,
116+
override_config.base_delay,
117+
override_config.max_delay,
118+
override_config.max_tries,
119+
override_config.factor,
120+
id(override_config.jitter) if override_config.jitter is not None else None,
121+
id(override_config.giveup_func) if override_config.giveup_func is not None else None,
122+
override_config.raise_on_giveup,
123+
)
124+
125+
if config_key not in config_to_exceptions:
126+
config_to_exceptions[config_key] = (set(), override_config)
127+
exc_set, _ = config_to_exceptions[config_key]
128+
exc_set.add(exc_type)
129+
130+
# If all exceptions use the same config, use a single decorator
131+
if len(config_to_exceptions) == 1:
132+
exc_set, config = next(iter(config_to_exceptions.values()))
133+
return self._create_single_decorator(exc_set, config)
134+
135+
# Create separate decorators for each config group
136+
# Each exception type gets exactly one decorator, preventing double backoff
137+
decorators_by_config: list[tuple[Set[Type[Exception]], Callable]] = []
138+
139+
for exc_set, config in config_to_exceptions.values():
140+
decorator = self._create_single_decorator(exc_set, config)
141+
if decorator:
142+
decorators_by_config.append((exc_set, decorator))
143+
144+
# Create a combined decorator that applies all decorators
145+
# Each decorator only catches exceptions in its exception set, so no double backoff
146+
def combined_decorator(func):
147+
decorated_func = func
148+
149+
# Apply each decorator in order (inner to outer)
150+
# Each decorator only catches exceptions in its specific exception set
151+
# Since exception sets are disjoint (grouped by config), no double backoff
152+
for exc_set, decorator in decorators_by_config:
153+
decorated_func = decorator(decorated_func)
154+
155+
return decorated_func
156+
157+
return combined_decorator
158+
159+
def _create_single_decorator(self, exc_set: Set[Type[Exception]], config: "BackoffConfig"):
    """Build one ``backoff.on_exception`` decorator for a group of exception types.

    Args:
        exc_set: Exception types the decorator should retry on.
        config: Backoff settings to apply (strategy, delays, tries, jitter, giveup).

    Returns:
        ``None`` when ``exc_set`` is empty, otherwise the decorator returned by
        ``backoff.on_exception`` for the configured strategy.

    Raises:
        ValueError: If ``config.strategy`` is neither ``"expo"`` nor ``"constant"``.
    """
    if not exc_set:
        return None

    strategy = config.strategy
    if strategy not in ("expo", "constant"):
        raise ValueError(f"Unknown backoff strategy: {config.strategy}")

    caught = tuple(exc_set)
    # Keyword arguments common to both strategies.
    shared_options = {
        "max_tries": config.max_tries,
        "jitter": config.jitter,
        "giveup": config.giveup_func,
        "raise_on_giveup": config.raise_on_giveup,
    }

    if strategy == "constant":
        return backoff.on_exception(
            backoff.constant,
            caught,
            interval=config.base_delay,
            **shared_options,
        )

    # NOTE(review): base_delay is forwarded as backoff.expo's ``base`` argument -
    # confirm that this is the intended meaning of base_delay for the expo strategy.
    return backoff.on_exception(
        backoff.expo,
        caught,
        base=config.base_delay,
        max_value=config.max_delay,
        factor=config.factor,
        **shared_options,
    )
114188

115189

116190
@dataclass
@@ -123,6 +197,10 @@ class ExceptionHandlerConfig:
123197
# Backoff configuration
124198
backoff_config: BackoffConfig = field(default_factory=BackoffConfig)
125199

200+
# Per-exception backoff overrides - allows custom backoff config for specific exception types
201+
# For example, ResponseQualityError can use no backoff (base_delay=0, max_delay=0)
202+
exception_backoff_overrides: Dict[Type[Exception], BackoffConfig] = field(default_factory=dict)
203+
126204
def __post_init__(self):
127205
"""Automatically apply environment variable overrides after initialization."""
128206
# Override backoff settings from environment variables
@@ -133,10 +211,23 @@ def __post_init__(self):
133211
if "EP_FAIL_ON_MAX_RETRY" in os.environ:
134212
fail_on_max_retry = os.environ["EP_FAIL_ON_MAX_RETRY"].lower()
135213
self.backoff_config.raise_on_giveup = fail_on_max_retry != "false"
214+
215+
# Set default no-backoff config for ResponseQualityError if not already set
216+
if eval_protocol.exceptions.ResponseQualityError not in self.exception_backoff_overrides:
217+
# Default: no backoff for ResponseQualityError (immediate retry)
218+
self.exception_backoff_overrides[eval_protocol.exceptions.ResponseQualityError] = BackoffConfig(
219+
strategy="constant",
220+
base_delay=0.0,
221+
max_delay=0.0,
222+
max_tries=self.backoff_config.max_tries,
223+
)
136224

137225
def get_backoff_decorator(self):
    """Build the retry decorator for this exception handler.

    Delegates to ``self.backoff_config.get_backoff_decorator`` with the retryable
    exception set and the per-exception overrides. An empty override mapping is
    forwarded as ``None``.
    """
    overrides = self.exception_backoff_overrides
    if not overrides:
        overrides = None
    return self.backoff_config.get_backoff_decorator(self.retryable_exceptions, overrides)
140231

141232

142233
def get_default_exception_handler_config() -> ExceptionHandlerConfig:

eval_protocol/pytest/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from ..models import CompletionParams, EvaluationRow, Message
1313
from .exception_config import ExceptionHandlerConfig
14+
from .rollout_result_post_processor import RolloutResultPostProcessor
1415

1516
ModelParam = str # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
1617
DatasetPathParam = str
@@ -75,3 +76,4 @@ class RolloutProcessorConfig:
7576
default_factory=dict
7677
) # any additional kwargs to pass to the rollout processor
7778
exception_handler_config: ExceptionHandlerConfig | None = None # configuration for exception handling with backoff
79+
post_processor: RolloutResultPostProcessor | None = None # optional post-processor for quality checks

0 commit comments

Comments
 (0)