Skip to content

Commit 4971888

Browse files
authored
Response validation plugin (#351)
* adding response quality validation for retry
* Add rollout result post processor
* simplify the config
1 parent 79e3686 commit 4971888

File tree

10 files changed

+326
-9
lines changed

10 files changed

+326
-9
lines changed

eval_protocol/exceptions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,12 @@ class ScoreInvalidError(EvalProtocolError):
134134
status_code = 102
135135

136136

137+
class ResponseQualityError(EvalProtocolError):
    """Raised when a response fails a quality check.

    Carries status code 103, matching Status.Code.RESPONSE_QUALITY_ERROR.
    """

    status_code = 103
141+
142+
137143
# Convenience mapping from status codes to exception classes
138144
# Only actual error conditions should raise exceptions
139145
STATUS_CODE_TO_EXCEPTION = {
@@ -157,6 +163,7 @@ class ScoreInvalidError(EvalProtocolError):
157163
100: None, # FINISHED - success, no exception
158164
101: None, # RUNNING - in progress, no exception
159165
102: None, # SCORE_INVALID - success, no exception
166+
103: ResponseQualityError, # RESPONSE_QUALITY_ERROR - quality check failed
160167
}
161168

162169

eval_protocol/models.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ class Code(int, Enum):
117117
FINISHED = 100
118118
RUNNING = 101
119119
SCORE_INVALID = 102
120+
RESPONSE_QUALITY_ERROR = 103
120121

121122
@classmethod
122123
def rollout_running(cls) -> "Status":
@@ -367,6 +368,13 @@ def score_invalid(
367368
"""Create a status indicating the score is invalid."""
368369
return cls(code=cls.Code.SCORE_INVALID, message=message, details=details or [])
369370

371+
@classmethod
def response_quality_error(
    cls,
    message: str = "Response quality check failed",
    details: Optional[List[Dict[str, Any]]] = None,
) -> "Status":
    """Build a status reporting that the response quality check failed.

    Args:
        message: Human-readable description of the failure.
        details: Optional list of detail dicts; a falsy value becomes ``[]``.

    Returns:
        A status whose code is ``Code.RESPONSE_QUALITY_ERROR``.
    """
    return cls(
        code=cls.Code.RESPONSE_QUALITY_ERROR,
        message=message,
        details=details if details else [],
    )
377+
370378
def is_running(self) -> bool:
    """Return True when this status's code equals ``Code.RUNNING``."""
    current = self.code
    return current == self.Code.RUNNING

eval_protocol/pytest/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from .evaluation_test import evaluation_test
99
from .exception_config import ExceptionHandlerConfig, BackoffConfig, get_default_exception_handler_config
1010
from .rollout_processor import RolloutProcessor
11+
from .rollout_result_post_processor import RolloutResultPostProcessor, NoOpRolloutResultPostProcessor
1112
from .types import RolloutProcessorConfig
1213

1314
# Conditional import for optional dependencies
@@ -42,6 +43,8 @@
4243
"ExceptionHandlerConfig",
4344
"BackoffConfig",
4445
"get_default_exception_handler_config",
46+
"RolloutResultPostProcessor",
47+
"NoOpRolloutResultPostProcessor",
4548
]
4649

4750
# Only add to __all__ if available

eval_protocol/pytest/evaluation_test_utils.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
ServerMode,
2929
)
3030
from eval_protocol.pytest.exception_config import get_default_exception_handler_config
31+
from eval_protocol.exceptions import ResponseQualityError
3132

3233
import logging
3334
import json
@@ -363,7 +364,18 @@ async def execute_row_with_backoff_retry(row: EvaluationRow) -> EvaluationRow:
363364
"""Execute rollout for a single row with backoff retry."""
364365
retry_config = replace(config, kwargs={**(config.kwargs or {}), "start_server": False})
365366
retry_tasks = rollout_processor([row], retry_config)
366-
return await retry_tasks[0]
367+
result = await retry_tasks[0]
368+
369+
# Apply post-processing quality checks if configured.
# This must be inside the retry function so ResponseQualityError can trigger
# retries. The exception is deliberately not caught here: letting it
# propagate is equivalent to catching and re-raising it, without the extra
# traceback frame.
if config.post_processor is not None:
    config.post_processor.process(result)
377+
378+
return result
367379

368380
async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: EvaluationRow) -> EvaluationRow:
369381
"""Execute a single row task with backoff retry."""
@@ -372,6 +384,13 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu
372384
# Try original task first
373385
result = await task # pyright: ignore[reportUnknownVariableType]
374386

387+
# Apply post-processing quality checks if configured.
# A ResponseQualityError raised by the post-processor is left to propagate
# as-is; catching it only to re-raise the same object adds nothing.
if config.post_processor is not None:
    config.post_processor.process(result)
393+
375394
_set_rollout_status_to_finished(result)
376395

377396
return result # pyright: ignore[reportUnknownVariableType]
@@ -384,9 +403,9 @@ async def execute_row_with_backoff(task: asyncio.Task[EvaluationRow], row: Evalu
384403

385404
if is_retryable and not should_giveup:
386405
# Use shared backoff function for retryable exceptions
406+
# Note: post-processing is handled inside execute_row_with_backoff_retry
387407
try:
388408
result = await execute_row_with_backoff_retry(row)
389-
390409
_set_rollout_status_to_finished(result)
391410

392411
return result

eval_protocol/pytest/exception_config.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import os
66
from dataclasses import dataclass, field
7-
from typing import Callable, Set, Type, Union
7+
from typing import Callable, Dict, Set, Type, Union
88

99
import backoff
1010

@@ -47,6 +47,7 @@
4747
eval_protocol.exceptions.UnavailableError,
4848
eval_protocol.exceptions.UnauthenticatedError,
4949
eval_protocol.exceptions.ResourceExhaustedError,
50+
eval_protocol.exceptions.ResponseQualityError,
5051
}
5152

5253

@@ -79,7 +80,11 @@ class BackoffConfig:
7980
giveup_func: Callable[[Exception], bool] = lambda e: False
8081

8182
def get_backoff_decorator(self, exceptions: Set[Type[Exception]]):
82-
"""Get the appropriate backoff decorator based on configuration."""
83+
"""Get the appropriate backoff decorator based on configuration.
84+
85+
Args:
86+
exceptions: Set of exception types to retry
87+
"""
8388
if not exceptions:
8489
# If no exceptions specified, return a no-op decorator
8590
def no_op_decorator(func):
@@ -136,7 +141,9 @@ def __post_init__(self):
136141

137142
def get_backoff_decorator(self):
    """Return the backoff decorator built from this handler's configuration.

    Delegates to ``backoff_config.get_backoff_decorator`` with this handler's
    ``retryable_exceptions``.
    """
    build_decorator = self.backoff_config.get_backoff_decorator
    return build_decorator(self.retryable_exceptions)
140147

141148

142149
def get_default_exception_handler_config() -> ExceptionHandlerConfig:
eval_protocol/pytest/rollout_result_post_processor.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
"""
2+
Rollout result post-processing plugin for quality checks.
3+
4+
This module provides an abstract base class for post-processing rollout results
5+
to guard response quality. Post-processors can validate results and raise
6+
ResponseQualityError if quality checks fail.
7+
"""
8+
9+
from abc import ABC, abstractmethod
10+
11+
from eval_protocol.models import EvaluationRow
12+
13+
14+
class RolloutResultPostProcessor(ABC):
    """Interface for plugins that inspect a rollout result after it is produced.

    Implementations act as quality guards: they examine the result and signal
    a failed check by raising ResponseQualityError. Users can provide their
    own subclass to customize which checks are applied.
    """

    @abstractmethod
    def process(self, result: EvaluationRow) -> None:
        """Validate a single rollout result.

        Implementations should run their quality checks on ``result`` and
        raise ResponseQualityError with a descriptive message when a check
        fails.

        Args:
            result: The EvaluationRow result from the rollout.

        Raises:
            ResponseQualityError: If quality checks fail.
        """
39+
40+
41+
class NoOpRolloutResultPostProcessor(RolloutResultPostProcessor):
    """Post-processor that accepts every result without checking it.

    Intended as the default when no post-processing is needed.
    """

    def process(self, result: EvaluationRow) -> None:
        """Accept ``result`` unconditionally; no quality checks are run.

        Args:
            result: The EvaluationRow result from the rollout (not inspected).
        """
        return None
57+

eval_protocol/pytest/types.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from ..models import CompletionParams, EvaluationRow, Message
1313
from .exception_config import ExceptionHandlerConfig
14+
from .rollout_result_post_processor import RolloutResultPostProcessor
1415

1516
ModelParam = str # gpt-4o, gpt-4o-mini, accounts/fireworks/models/llama-3.1-8b-instruct
1617
DatasetPathParam = str
@@ -75,3 +76,4 @@ class RolloutProcessorConfig:
7576
default_factory=dict
7677
) # any additional kwargs to pass to the rollout processor
7778
exception_handler_config: ExceptionHandlerConfig | None = None # configuration for exception handling with backoff
79+
post_processor: RolloutResultPostProcessor | None = None # optional post-processor for quality checks

tests/test_exception_config.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
"""
2+
Unit tests for exception_config module.
3+
4+
Tests the BackoffConfig and ExceptionHandlerConfig classes, including:
5+
1. Backoff decorator creation (no-op, constant, and expo strategies)
6+
2. Rejection of unknown backoff strategies with ValueError
7+
3. ResponseQualityError being retryable by default
8+
4. ExceptionHandlerConfig.get_backoff_decorator()
9+
"""
10+
11+
import pytest
12+
from eval_protocol.pytest.exception_config import BackoffConfig, ExceptionHandlerConfig, DEFAULT_RETRYABLE_EXCEPTIONS
13+
from eval_protocol.exceptions import ResponseQualityError
14+
15+
16+
def test_backoff_config_no_exceptions():
    """An empty exception set yields a decorator that hands the function back untouched."""
    passthrough = BackoffConfig().get_backoff_decorator(set())

    def sample():
        return "test"

    wrapped = passthrough(sample)

    # Check identity as well as behavior: the no-op decorator must not wrap.
    assert wrapped() == "test"
    assert wrapped is sample
28+
29+
30+
def test_backoff_config_no_overrides():
    """A BackoffConfig returns one usable decorator for a set of several exception types."""
    retry_on = {ConnectionError, TimeoutError}
    backoff_config = BackoffConfig(strategy="constant", base_delay=0.1, max_tries=2)

    decorator = backoff_config.get_backoff_decorator(retry_on)
    assert decorator is not None

    def always_fails():
        raise ConnectionError("test")

    # Only decoration is checked here; the wrapped function is never invoked.
    assert callable(decorator(always_fails))
44+
45+
46+
def test_exception_handler_config_default_response_quality_error():
    """A default-constructed ExceptionHandlerConfig treats ResponseQualityError as retryable."""
    retryable = ExceptionHandlerConfig().retryable_exceptions

    assert ResponseQualityError in retryable
52+
53+
54+
def test_exception_handler_config_get_backoff_decorator():
    """ExceptionHandlerConfig.get_backoff_decorator() returns a callable that can decorate a function."""
    decorator = ExceptionHandlerConfig().get_backoff_decorator()

    assert decorator is not None
    assert callable(decorator)

    def always_fails():
        raise ConnectionError("test")

    # Decorating must succeed; the wrapped function is never invoked.
    assert callable(decorator(always_fails))
68+
69+
70+
def test_backoff_config_expo_strategy():
    """The "expo" strategy produces a decorator that can wrap a function."""
    expo_config = BackoffConfig(strategy="expo", base_delay=1.0, max_tries=2)

    decorator = expo_config.get_backoff_decorator({ConnectionError})
    assert decorator is not None

    def always_fails():
        raise ConnectionError("test")

    # Only decoration is checked here; the wrapped function is never invoked.
    assert callable(decorator(always_fails))
84+
85+
86+
def test_backoff_config_constant_strategy():
    """The "constant" strategy produces a decorator that can wrap a function."""
    constant_config = BackoffConfig(strategy="constant", base_delay=0.1, max_tries=2)

    decorator = constant_config.get_backoff_decorator({ConnectionError})
    assert decorator is not None

    def always_fails():
        raise ConnectionError("test")

    # Only decoration is checked here; the wrapped function is never invoked.
    assert callable(decorator(always_fails))
99+
100+
101+
def test_backoff_config_invalid_strategy():
    """An unrecognized strategy name is rejected with ValueError when the decorator is requested."""
    bad_config = BackoffConfig(strategy="invalid", base_delay=1.0, max_tries=2)

    # Construction happens outside the block: the error is expected from
    # get_backoff_decorator, not from BackoffConfig().
    with pytest.raises(ValueError, match="Unknown backoff strategy"):
        bad_config.get_backoff_decorator({ConnectionError})
108+
109+
110+
def test_exception_handler_config_response_quality_error_in_defaults():
    """DEFAULT_RETRYABLE_EXCEPTIONS contains ResponseQualityError."""
    default_retryable = DEFAULT_RETRYABLE_EXCEPTIONS

    assert ResponseQualityError in default_retryable
113+
114+

0 commit comments

Comments
 (0)