Feat: Content Filtering Exception Handling #3634
base: main
Changes from all commits: 4ac81d4, 658f407, 404a833, 3771c79, 28e20c8, 70bcb74, e861a50, 51506b2, b67d216, 4ac1608
Exceptions:

@@ -24,6 +24,9 @@
     'UsageLimitExceeded',
     'ModelAPIError',
     'ModelHTTPError',
+    'ContentFilterError',
+    'PromptContentFilterError',
+    'ResponseContentFilterError',
     'IncompleteToolCall',
     'FallbackExceptionGroup',
 )

@@ -179,6 +182,30 @@ def __init__(self, status_code: int, model_name: str, body: object | None = None
         super().__init__(model_name=model_name, message=message)


+class ContentFilterError(ModelHTTPError):
+    """Raised when content filtering is triggered by the model provider."""
+
+    def __init__(self, message: str, status_code: int, model_name: str, body: object | None = None):
+        super().__init__(status_code, model_name, body)
+        self.message = message
+
+
+class PromptContentFilterError(ContentFilterError):
+    """Raised when the prompt triggers a content filter."""
+
+    def __init__(self, status_code: int, model_name: str, body: object | None = None):
+        message = f"Model '{model_name}' content filter was triggered by the user's prompt"
+        super().__init__(message, status_code, model_name, body)
+
+
+class ResponseContentFilterError(ContentFilterError):
+    """Raised when the generated response triggers a content filter."""
+
+    def __init__(self, model_name: str, body: object | None = None, status_code: int = 200):
+        message = f"Model '{model_name}' triggered its content filter while generating a response"
+        super().__init__(message, status_code, model_name, body)
+
+
 class FallbackExceptionGroup(ExceptionGroup[Any]):
     """A group of exceptions that can be raised when all fallback models fail."""

Collaborator (on the PromptContentFilterError message): Or by the instructions, right? I'd prefer it to be a bit more generic.
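For orientation, a rough sketch of how a caller might use the new hierarchy, assuming the classes are exported from pydantic_ai.exceptions as the updated __all__ above suggests (the model string is purely illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.exceptions import PromptContentFilterError, ResponseContentFilterError

agent = Agent('openai:gpt-4o')  # illustrative model name

try:
    result = agent.run_sync('Summarise this document.')
    print(result.output)
except PromptContentFilterError as e:
    # The provider rejected the input before generating anything
    # (e.g. Azure's 400 response with error code 'content_filter').
    print(f'prompt filtered by {e.model_name}: {e.message}')
except ResponseContentFilterError as e:
    # The provider stopped generation with a content-filter finish reason.
    print(f'response filtered by {e.model_name}')
```

Catching the shared ContentFilterError base covers both cases, and because it subclasses ModelHTTPError, existing except ModelHTTPError handlers keep working.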
Anthropic model:

@@ -14,7 +14,7 @@
 from .._run_context import RunContext
 from .._utils import guard_tool_call_id as _guard_tool_call_id
 from ..builtin_tools import CodeExecutionTool, MCPServerTool, MemoryTool, WebFetchTool, WebSearchTool
-from ..exceptions import ModelAPIError, UserError
+from ..exceptions import ModelAPIError, ResponseContentFilterError, UserError
 from ..messages import (
     BinaryContent,
     BuiltinToolCallPart,

@@ -526,6 +526,11 @@ def _process_response(self, response: BetaMessage) -> ModelResponse:
         if raw_finish_reason := response.stop_reason:  # pragma: no branch
             provider_details = {'finish_reason': raw_finish_reason}
             finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+            if finish_reason == 'content_filter':
+                raise ResponseContentFilterError(
+                    model_name=response.model,
+                    body=response.model_dump(),
+                )

         return ModelResponse(
             parts=items,

Collaborator (on the new raise in _process_response): I think we should do this in […]. That would mean that it wouldn't trigger […].

@@ -1243,6 +1248,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                 if raw_finish_reason := event.delta.stop_reason:  # pragma: no branch
                     self.provider_details = {'finish_reason': raw_finish_reason}
                     self.finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
+                    if self.finish_reason == 'content_filter':
+                        raise ResponseContentFilterError(
+                            model_name=self.model_name,
+                        )

             elif isinstance(event, BetaRawContentBlockStopEvent):  # pragma: no branch
                 if isinstance(current_block, BetaMCPToolUseBlock):
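Since the streaming path raises from inside the event iterator, the error surfaces while the caller is consuming the stream. A minimal sketch of handling that, assuming the usual Agent streaming API (model name illustrative); note that this code path passes no body, so e.body is None:

```python
import asyncio

from pydantic_ai import Agent
from pydantic_ai.exceptions import ResponseContentFilterError

agent = Agent('anthropic:claude-3-5-sonnet-latest')  # illustrative model name


async def main() -> None:
    try:
        async with agent.run_stream('Tell me a story.') as result:
            async for text in result.stream_text(delta=True):
                print(text, end='', flush=True)
    except ResponseContentFilterError as e:
        # Raised when the provider reports a content-filter stop reason mid-stream.
        print(f'\nstream stopped by the content filter of {e.model_name}')


asyncio.run(main())
```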
Google model:

@@ -14,7 +14,7 @@
 from .._output import OutputObjectDefinition
 from .._run_context import RunContext
 from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, WebFetchTool, WebSearchTool
-from ..exceptions import ModelAPIError, ModelHTTPError, UserError
+from ..exceptions import ModelAPIError, ModelHTTPError, ResponseContentFilterError, UserError
 from ..messages import (
     BinaryContent,
     BuiltinToolCallPart,

@@ -495,8 +495,8 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:

         if candidate.content is None or candidate.content.parts is None:
             if finish_reason == 'content_filter' and raw_finish_reason:
-                raise UnexpectedModelBehavior(
-                    f'Content filter {raw_finish_reason.value!r} triggered', response.model_dump_json()
+                raise ResponseContentFilterError(
+                    model_name=response.model_version or self._model_name, body=response.model_dump_json()
                 )
             parts = []  # pragma: no cover
         else:

Collaborator (on the replaced raise): This is a breaking change for users that currently have […].

Collaborator: Combined with the above, the solution may be to remove this check here, and to implement the same […]. If we want to get some details from the response up to that level, we can store them in […].

@@ -697,10 +697,8 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                     yield self._parts_manager.handle_part(vendor_part_id=uuid4(), part=web_fetch_return)

             if candidate.content is None or candidate.content.parts is None:
-                if self.finish_reason == 'content_filter' and raw_finish_reason:  # pragma: no cover
-                    raise UnexpectedModelBehavior(
-                        f'Content filter {raw_finish_reason.value!r} triggered', chunk.model_dump_json()
-                    )
+                if self.finish_reason == 'content_filter' and raw_finish_reason:
+                    raise ResponseContentFilterError(model_name=self.model_name, body=chunk.model_dump_json())
                 else:  # pragma: no cover
                     continue
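As the review comments point out, downstream code that currently catches UnexpectedModelBehavior for Gemini content-filter stops would stop matching once this lands. A transitional sketch that works on both sides of the change (the model string and handler are illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.exceptions import ResponseContentFilterError, UnexpectedModelBehavior

agent = Agent('google-gla:gemini-2.0-flash')  # illustrative model name


def handle_filtered(exc: Exception) -> None:
    """Placeholder for whatever recovery or logging the application needs."""
    print(f'content filter tripped: {exc}')


try:
    result = agent.run_sync('...')
except ResponseContentFilterError as exc:
    # Behaviour in this branch: a 'content_filter' finish reason raises a dedicated error.
    handle_filtered(exc)
except UnexpectedModelBehavior as exc:
    # Current released behaviour; the old message starts with 'Content filter ... triggered'.
    if 'Content filter' in str(exc):
        handle_filtered(exc)
    else:
        raise
```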
OpenAI model:

@@ -20,7 +20,7 @@
 from .._thinking_part import split_content_into_text_and_thinking
 from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc, number_to_datetime
 from ..builtin_tools import CodeExecutionTool, ImageGenerationTool, MCPServerTool, WebSearchTool
-from ..exceptions import UserError
+from ..exceptions import PromptContentFilterError, ResponseContentFilterError, UserError
 from ..messages import (
     AudioUrl,
     BinaryContent,

@@ -160,6 +160,24 @@
 }


+def _check_azure_content_filter(e: APIStatusError, model_name: str) -> None:
+    """Check if the error is an Azure content filter error and raise PromptContentFilterError if so."""
+    if e.status_code == 400:
+        body_any: Any = e.body
+
+        if isinstance(body_any, dict):
+            body_dict = cast(dict[str, Any], body_any)
+
+            if (error := body_dict.get('error')) and isinstance(error, dict):
+                error_dict = cast(dict[str, Any], error)
+                if error_dict.get('code') == 'content_filter':
+                    raise PromptContentFilterError(
+                        status_code=e.status_code,
+                        model_name=model_name,
+                        body=body_dict,
+                    ) from e
+
+
 class OpenAIChatModelSettings(ModelSettings, total=False):
     """Settings used for an OpenAI model request."""

Collaborator (on the raise in _check_azure_content_filter): What do you think about instead handling this by returning a […]? I also don't think they distinguish between input/request and output/response content filter errors, so I don't think we need to here either.
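For reference, the body shape the helper above is looking for is roughly the following; this is an approximation of Azure OpenAI's 400 payload for a filtered prompt, not a verbatim example:

```python
from typing import Any

# Approximate shape of the 400 body Azure OpenAI returns when the prompt is filtered
# (illustrative only; field values are made up).
body: dict[str, Any] = {
    'error': {
        'code': 'content_filter',
        'message': 'The prompt was filtered due to the content management policy.',
        'param': 'prompt',
        'status': 400,
    }
}

# _check_azure_content_filter() only promotes the error when the status is 400 and
# body['error']['code'] == 'content_filter'; everything else still becomes ModelHTTPError.
is_prompt_filter = (
    isinstance(body.get('error'), dict) and body['error'].get('code') == 'content_filter'
)
print(is_prompt_filter)  # True
```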
@@ -555,6 +573,8 @@ async def _completions_create(
             )
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
+                _check_azure_content_filter(e, self.model_name)
+
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
             raise  # pragma: lax no cover
         except APIConnectionError as e:

@@ -601,6 +621,13 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelRespons
             raise UnexpectedModelBehavior(f'Invalid response from {self.system} chat completions endpoint: {e}') from e

         choice = response.choices[0]

+        if choice.finish_reason == 'content_filter':
+            raise ResponseContentFilterError(
+                model_name=response.model,
+                body=response.model_dump(),
+            )
+
         items: list[ModelResponsePart] = []

         if thinking_parts := self._process_thinking(choice.message):

@@ -1242,6 +1269,11 @@ def _process_response(  # noqa: C901
         finish_reason: FinishReason | None = None
         provider_details: dict[str, Any] | None = None
         raw_finish_reason = details.reason if (details := response.incomplete_details) else response.status
+        if raw_finish_reason == 'content_filter':
+            raise ResponseContentFilterError(
+                model_name=response.model,
+                body=response.model_dump(),
+            )
         if raw_finish_reason:
             provider_details = {'finish_reason': raw_finish_reason}
             finish_reason = _RESPONSES_FINISH_REASON_MAP.get(raw_finish_reason)

@@ -1398,6 +1430,8 @@ async def _responses_create(  # noqa: C901
             )
         except APIStatusError as e:
             if (status_code := e.status_code) >= 400:
+                _check_azure_content_filter(e, self.model_name)
+
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
             raise  # pragma: lax no cover
         except APIConnectionError as e:

@@ -1903,6 +1937,10 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                     continue

                 if raw_finish_reason := choice.finish_reason:
+                    if raw_finish_reason == 'content_filter':
+                        raise ResponseContentFilterError(
+                            model_name=self.model_name,
+                        )
                     self.finish_reason = self._map_finish_reason(raw_finish_reason)

                 if provider_details := self._map_provider_details(chunk):

@@ -2047,6 +2085,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                     raw_finish_reason = (
                         details.reason if (details := chunk.response.incomplete_details) else chunk.response.status
                     )
+
+                    if raw_finish_reason == 'content_filter':
+                        raise ResponseContentFilterError(
+                            model_name=self.model_name,
+                        )
+
                     if raw_finish_reason:  # pragma: no branch
                         self.provider_details = {'finish_reason': raw_finish_reason}
                         self.finish_reason = _RESPONSES_FINISH_REASON_MAP.get(raw_finish_reason)
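One interaction worth flagging: because ContentFilterError subclasses ModelHTTPError, FallbackModel's default fallback condition would treat a content-filter rejection like any other HTTP failure and try the next model. If that is not wanted, a custom predicate can opt out; this is only a sketch and assumes fallback_on accepts a callable (model names illustrative):

```python
from pydantic_ai.exceptions import ContentFilterError, ModelHTTPError
from pydantic_ai.models.fallback import FallbackModel


def should_fall_back(exc: Exception) -> bool:
    # Fall back on ordinary HTTP failures, but not on content-filter rejections,
    # since another model is unlikely to accept the same filtered prompt.
    return isinstance(exc, ModelHTTPError) and not isinstance(exc, ContentFilterError)


model = FallbackModel(
    'openai:gpt-4o',
    'anthropic:claude-3-5-sonnet-latest',
    fallback_on=should_fall_back,
)
```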
Review comment: Can we name this RequestContentFilterError, for consistency with our ModelRequest / ModelResponse?
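If that rename is adopted, one low-friction way to do it, building on the ContentFilterError base from the diff above (the alias and the more generic message are assumptions, not part of this PR):

```python
class RequestContentFilterError(ContentFilterError):
    """Raised when the request (prompt, instructions, etc.) triggers a content filter."""

    def __init__(self, status_code: int, model_name: str, body: object | None = None):
        message = f"Model '{model_name}' content filter was triggered by the request"
        super().__init__(message, status_code, model_name, body)


# Optional backwards-compatible alias while downstream code migrates:
PromptContentFilterError = RequestContentFilterError
```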