class ModifyResponseException(Exception):
    """
    Exception raised when a guardrail wants to modify the response.

    This exception carries the synthetic response that should be returned
    to the user instead of calling the LLM or instead of the LLM's response.
    It should be caught by the proxy and returned with a 200 status code.

    This is a base exception that all guardrails can use to replace responses,
    allowing violation messages to be returned as successful responses
    rather than errors.
    """

    def __init__(
        self,
        message: str,
        model: str,
        request_data: Dict[str, Any],
        guardrail_name: Optional[str] = None,
        detection_info: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize the modify response exception.

        Args:
            message: The violation message to return to the user
            model: The model that was being called
            request_data: The original request data
            guardrail_name: Name of the guardrail that raised this exception
            detection_info: Additional detection metadata (scores, rules, etc.);
                a falsy value (``None`` or an empty dict) is stored as a new
                empty dict
        """
        self.message = message
        self.model = model
        self.request_data = request_data
        self.guardrail_name = guardrail_name
        self.detection_info = detection_info or {}
        # Only ``message`` is forwarded so that ``str(exc)`` and ``exc.args``
        # stay ``message`` / ``(message,)``.
        super().__init__(message)

    def __reduce__(self):
        """
        Describe how to rebuild this exception for ``copy`` and ``pickle``.

        The inherited implementation replays ``cls(*self.args)``. Because
        ``self.args`` is just ``(message,)`` while ``__init__`` also requires
        ``model`` and ``request_data``, that call raises ``TypeError``. This
        override supplies the full constructor argument list, plus the
        instance ``__dict__`` so attributes attached after construction
        survive the round trip.
        """
        constructor_args = (
            self.message,
            self.model,
            self.request_data,
            self.guardrail_name,
            self.detection_info,
        )
        return (self.__class__, constructor_args, dict(vars(self)))
- It should be caught by the proxy and returned with a 200 status code. - - This is a base exception that all guardrails can use to replace responses, - allowing violation messages to be returned as successful responses - rather than errors. - """ - - def __init__( - self, - message: str, - model: str, - request_data: Dict[str, Any], - guardrail_name: Optional[str] = None, - detection_info: Optional[Dict[str, Any]] = None, - ): - """ - Initialize the modify response exception. - - Args: - message: The violation message to return to the user - model: The model that was being called - request_data: The original request data - guardrail_name: Name of the guardrail that raised this exception - detection_info: Additional detection metadata (scores, rules, etc.) - """ - self.message = message - self.model = model - self.request_data = request_data - self.guardrail_name = guardrail_name - self.detection_info = detection_info or {} - super().__init__(message) +from litellm.exceptions import ModifyResponseException as ModifyResponseException class CustomGuardrail(CustomLogger): @@ -417,7 +381,9 @@ def should_run_guardrail( """ requested_guardrails = self.get_guardrail_from_metadata(data) disable_global_guardrail = self.get_disable_global_guardrail(data) - opted_out_global_guardrails = self.get_opted_out_global_guardrails_from_metadata(data) + opted_out_global_guardrails = ( + self.get_opted_out_global_guardrails_from_metadata(data) + ) verbose_logger.debug( "inside should_run_guardrail for guardrail=%s event_type= %s guardrail_supported_event_hooks= %s requested_guardrails= %s self.default_on= %s", self.guardrail_name, @@ -426,7 +392,10 @@ def should_run_guardrail( requested_guardrails, self.default_on, ) - if self.default_on is True and self.guardrail_name in opted_out_global_guardrails: + if ( + self.default_on is True + and self.guardrail_name in opted_out_global_guardrails + ): return False if self.default_on is True and disable_global_guardrail is not True: diff --git 
a/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py b/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py index a1623121da5..f29eab47cf9 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py +++ b/litellm/proxy/guardrails/guardrail_hooks/unified_guardrail/unified_guardrail.py @@ -227,6 +227,16 @@ async def async_post_call_success_hook( if call_type is None: call_type = _infer_call_type(call_type=None, completion_response=response) # type: ignore + # Fallback: resolve call_type from logging_obj for pass-through endpoints + if call_type is None: + litellm_logging_obj = data.get("litellm_logging_obj") + if ( + litellm_logging_obj is not None + and getattr(litellm_logging_obj, "call_type", None) + == CallTypes.pass_through.value + ): + call_type = CallTypes.pass_through.value + if call_type is None: return response diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index d582240395f..8104c587f65 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -635,6 +635,7 @@ async def pass_through_request( # noqa: PLR0915 custom_llm_provider: Optional field - custom LLM provider for the endpoint guardrails_config: Optional field - guardrails configuration for passthrough endpoint """ + from litellm.exceptions import ModifyResponseException from litellm.litellm_core_utils.litellm_logging import Logging from litellm.proxy.pass_through_endpoints.passthrough_guardrails import ( PassthroughGuardrailHandler, @@ -915,8 +916,41 @@ async def pass_through_request( # noqa: PLR0915 content = await response.aread() - ## LOG SUCCESS + ## POST-CALL GUARDRAILS ## + _content_modified = False response_body: Optional[dict] = get_response_body(response) + if response_body is not None and guardrails_to_run: + # Build an 
enriched data dict: _parsed_body has been stripped of + # `metadata` by both pre_call_hook and _init_kwargs_for_pass_through_endpoint, + # so we re-attach the configured guardrails here so should_run_guardrail + # sees them. + hook_data = dict(_parsed_body or {}) + existing_metadata = hook_data.get("metadata") + if not isinstance(existing_metadata, dict): + existing_metadata = {} + hook_data["metadata"] = { + **existing_metadata, + "guardrails": guardrails_to_run, + } + response_body = await proxy_logging_obj.post_call_success_hook( + data=hook_data, + user_api_key_dict=user_api_key_dict, + response=response_body, # type: ignore[arg-type] + ) + if isinstance(response_body, dict): + content = json.dumps(response_body).encode("utf-8") + _content_modified = True + else: + verbose_proxy_logger.debug( + "pass_through_endpoint: post_call_success_hook returned %s, expected dict — using original response", + type(response_body).__name__, + ) + elif response_body is None: + verbose_proxy_logger.debug( + "pass_through_endpoint: response body not JSON-parseable, skipping post-call guardrails" + ) + + ## LOG SUCCESS passthrough_logging_payload["response_body"] = response_body end_time = datetime.now() asyncio.create_task( @@ -944,13 +978,47 @@ async def pass_through_request( # noqa: PLR0915 api_base=str(url._uri_reference), ) + response_headers = HttpPassThroughEndpointHelpers.get_response_headers( + headers=response.headers, + custom_headers=custom_headers, + ) + if _content_modified: + response_headers.pop("content-length", None) + return Response( content=content, status_code=response.status_code, - headers=HttpPassThroughEndpointHelpers.get_response_headers( - headers=response.headers, - custom_headers=custom_headers, - ), + headers=response_headers, + ) + except ModifyResponseException as e: + verbose_proxy_logger.info( + "pass_through_endpoint: Guardrail %s modified response: %s", + e.guardrail_name, + str(e.message or "")[:200], + ) + try: + await 
proxy_logging_obj.post_call_failure_hook( + user_api_key_dict=user_api_key_dict, + original_exception=e, + request_data=e.request_data, + ) + except Exception: + verbose_proxy_logger.warning( + "pass_through_endpoint: post_call_failure_hook raised during guardrail block", + exc_info=True, + ) + error_body = { + "error": { + "message": e.message or "Response blocked by guardrail", + "type": "content_filter", + "guardrail_name": e.guardrail_name, + "model": e.model, + } + } + return Response( + content=json.dumps(error_body), + status_code=200, + media_type="application/json", ) except Exception as e: custom_headers = ProxyBaseLLMRequestProcessing.get_custom_headers( diff --git a/tests/test_litellm/proxy/pass_through_endpoints/test_passthrough_post_call_guardrails.py b/tests/test_litellm/proxy/pass_through_endpoints/test_passthrough_post_call_guardrails.py new file mode 100644 index 00000000000..f061434a971 --- /dev/null +++ b/tests/test_litellm/proxy/pass_through_endpoints/test_passthrough_post_call_guardrails.py @@ -0,0 +1,276 @@ +""" +Tests for post-call guardrail invocation on pass-through endpoints. + +Verifies that apply_guardrail(input_type="response") is called for +non-streaming pass-through responses. Addresses issue #20270. 
+""" + +import json +import sys +from contextlib import ExitStack +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from litellm.integrations.custom_guardrail import ( + CustomGuardrail, + ModifyResponseException, +) + +_PT_MOD = "litellm.proxy.pass_through_endpoints.pass_through_endpoints" +_COLLECT = "litellm.proxy.pass_through_endpoints.passthrough_guardrails.PassthroughGuardrailHandler.collect_guardrails" + +_GEMINI_RESPONSE = { + "candidates": [ + { + "content": { + "role": "model", + "parts": [{"text": "Hello"}], + } + } + ] +} + + +def _make_user_api_key_dict(**overrides): + d = MagicMock() + d.api_key = "sk-test" + d.user_id = "user-1" + d.team_id = "team-1" + d.org_id = None + d.request_route = "/vertex_ai/v1/projects/p/locations/l/publishers/google/models/gemini:generateContent" + for k, v in overrides.items(): + setattr(d, k, v) + return d + + +def _make_httpx_response(body: dict, status_code: int = 200) -> httpx.Response: + content = json.dumps(body).encode("utf-8") + return httpx.Response( + status_code=status_code, + headers={"content-type": "application/json"}, + content=content, + request=httpx.Request("POST", "https://example.com/v1/generateContent"), + ) + + +def _make_mock_request(): + mock_request = MagicMock() + mock_request.method = "POST" + mock_request.query_params = {} + mock_request.headers = MagicMock() + mock_request.headers.copy.return_value = {} + return mock_request + + +def _ensure_proxy_server_mock(): + """Insert a mock proxy_server module if the real one can't import.""" + key = "litellm.proxy.proxy_server" + if key not in sys.modules: + mock_mod = MagicMock() + mock_mod.proxy_logging_obj = MagicMock() + sys.modules[key] = mock_mod + import litellm.proxy + + if not hasattr(litellm.proxy, "proxy_server"): + litellm.proxy.proxy_server = sys.modules[key] + + +_ensure_proxy_server_mock() + +from litellm.proxy.pass_through_endpoints.pass_through_endpoints import ( + pass_through_request, +) + + 
def _common_patches(mock_proxy_logging, mock_response):
    """Return a combined context manager for the patches shared by all tests.

    The patches are already active when this function returns; leaving the
    returned ``ExitStack`` (``with`` exit or ``close()``) undoes all of them.
    If entering any patch fails, the patches entered so far are undone before
    the error propagates, so a partial setup never leaks into other tests.

    Args:
        mock_proxy_logging: Object installed as
            ``litellm.proxy.proxy_server.proxy_logging_obj``.
        mock_response: Value returned by the patched
            ``non_streaming_http_request_handler``.

    Returns:
        An ``ExitStack`` that owns every active patch.
    """
    mock_async_client_obj = MagicMock()
    mock_async_client_obj.client = AsyncMock()

    mock_pt_logging = MagicMock()
    mock_pt_logging.pass_through_async_success_handler = AsyncMock()

    patches = (
        # Upstream HTTP call: hand back the canned response.
        patch(
            f"{_PT_MOD}.HttpPassThroughEndpointHelpers.non_streaming_http_request_handler",
            new_callable=AsyncMock,
            return_value=mock_response,
        ),
        patch(f"{_PT_MOD}._is_streaming_response", return_value=False),
        patch("litellm.proxy.proxy_server.proxy_logging_obj", mock_proxy_logging),
        patch(f"{_PT_MOD}.pass_through_endpoint_logging", mock_pt_logging),
        patch(f"{_PT_MOD}.get_async_httpx_client", return_value=mock_async_client_obj),
        patch(f"{_PT_MOD}._read_request_body", new_callable=AsyncMock, return_value={}),
        patch(f"{_PT_MOD}._safe_get_request_headers", return_value={}),
    )

    with ExitStack() as stack:
        for patcher in patches:
            stack.enter_context(patcher)
        # Every patch entered successfully: hand ownership to the caller so
        # leaving this ``with`` block does not undo them.
        return stack.pop_all()
mock_proxy_logging.post_call_success_hook.call_args + assert call_kwargs.kwargs["response"] == _GEMINI_RESPONSE + + @patch(_COLLECT, return_value=[]) + async def test_post_call_success_hook_skipped_when_no_guardrails( + self, + mock_collect, + ): + """post_call_success_hook should NOT fire when no guardrails are configured.""" + mock_response = _make_httpx_response(_GEMINI_RESPONSE) + + mock_proxy_logging = MagicMock() + mock_proxy_logging.pre_call_hook = AsyncMock(return_value={}) + mock_proxy_logging.post_call_success_hook = AsyncMock() + + with _common_patches(mock_proxy_logging, mock_response): + result = await pass_through_request( + request=_make_mock_request(), + target="https://example.com/v1/generateContent", + custom_headers={"Content-Type": "application/json"}, + user_api_key_dict=_make_user_api_key_dict(), + stream=False, + ) + + mock_proxy_logging.post_call_success_hook.assert_not_awaited() + assert result.status_code == 200 + + @patch(_COLLECT, return_value=["rubrik"]) + async def test_modify_response_exception_returns_error( + self, + mock_collect, + ): + """ModifyResponseException from guardrail should return 200 with provider-agnostic error.""" + response_body = { + "candidates": [ + { + "content": { + "role": "model", + "parts": [ + {"functionCall": {"name": "dangerous_tool", "args": {}}} + ], + } + } + ] + } + mock_response = _make_httpx_response(response_body) + + mock_proxy_logging = MagicMock() + mock_proxy_logging.pre_call_hook = AsyncMock(return_value={}) + mock_proxy_logging.post_call_success_hook = AsyncMock( + side_effect=ModifyResponseException( + message="Tool dangerous_tool blocked by policy", + model="gemini-2.0-flash", + request_data={}, + guardrail_name="rubrik", + ) + ) + mock_proxy_logging.post_call_failure_hook = AsyncMock() + + with _common_patches(mock_proxy_logging, mock_response): + result = await pass_through_request( + request=_make_mock_request(), + target="https://example.com/v1/generateContent", + 
custom_headers={"Content-Type": "application/json"}, + user_api_key_dict=_make_user_api_key_dict(), + stream=False, + ) + + mock_proxy_logging.post_call_failure_hook.assert_awaited_once() + assert result.status_code == 200 + body = json.loads(result.body) + assert body["error"]["type"] == "content_filter" + assert body["error"]["message"] == "Tool dangerous_tool blocked by policy" + assert body["error"]["guardrail_name"] == "rubrik" + assert body["error"]["model"] == "gemini-2.0-flash" + + +@pytest.mark.asyncio +class TestUnifiedGuardrailCallTypeResolution: + + async def test_pass_through_call_type_resolved_from_logging_obj(self): + """Unified guardrail should resolve call_type from logging_obj for pass-through.""" + from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import ( + UnifiedLLMGuardrails, + ) + + unified = UnifiedLLMGuardrails() + + mock_guardrail = MagicMock(spec=CustomGuardrail) + mock_guardrail.guardrail_name = "test-guardrail" + mock_guardrail.should_run_guardrail.return_value = True + + mock_logging_obj = MagicMock() + mock_logging_obj.call_type = "pass_through_endpoint" + + user_api_key_dict = _make_user_api_key_dict() + + data = { + "guardrail_to_apply": mock_guardrail, + "litellm_logging_obj": mock_logging_obj, + } + + response_body = {"candidates": [{"content": {"parts": [{"text": "hello"}]}}]} + + with patch( + "litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail.load_guardrail_translation_mappings" + ) as mock_load: + mock_handler_instance = AsyncMock() + mock_handler_instance.process_output_response = AsyncMock( + return_value=response_body + ) + mock_handler_class = MagicMock(return_value=mock_handler_instance) + + from litellm.types.utils import CallTypes + + mock_load.return_value = {CallTypes.pass_through: mock_handler_class} + + result = await unified.async_post_call_success_hook( + data=data, + user_api_key_dict=user_api_key_dict, + response=response_body, + ) + + 
def test_modify_response_exception_importable_from_both_paths():
    """Both import paths must resolve to the very same ModifyResponseException class."""
    from litellm.exceptions import ModifyResponseException as canonical_cls
    from litellm.integrations.custom_guardrail import (
        ModifyResponseException as re_exported_cls,
    )

    # Identity, not equality: the re-export must be the same class object.
    same_object = canonical_cls is re_exported_cls
    assert same_object