Skip to content

Commit 244bdff

Browse files
Authored by Krish Dholakia
Merge pull request BerriAI#23509 from michelligabriele/fix/pass-through-duplicate-failure-logs
fix(proxy): prevent duplicate callback logs for pass-through endpoint failures
2 parents 0d7425a + bfcba21 commit 244bdff

File tree

5 files changed: 221 additions (+221) and 5 deletions (-5)

litellm/litellm_core_utils/litellm_logging.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2956,7 +2956,10 @@ def failure_handler( # noqa: PLR0915
29562956
callback_func=callback,
29572957
)
29582958
if (
2959-
isinstance(callback, CustomLogger) and is_sync_request
2959+
isinstance(callback, CustomLogger)
2960+
and is_sync_request
2961+
and self.call_type
2962+
!= CallTypes.pass_through.value
29602963
): # custom logger class
29612964
callback.log_failure_event(
29622965
start_time=start_time,

litellm/proxy/pass_through_endpoints/pass_through_endpoints.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -961,6 +961,8 @@ async def pass_through_request( # noqa: PLR0915
961961
if kwargs:
962962
for key, value in kwargs.items():
963963
request_payload[key] = value
964+
if logging_obj is not None:
965+
request_payload["litellm_logging_obj"] = logging_obj
964966

965967
if (
966968
"model" not in request_payload
@@ -1703,6 +1705,8 @@ def __init__(self, target_url: str):
17031705
if kwargs:
17041706
for key, value in kwargs.items():
17051707
request_payload[key] = value
1708+
if logging_obj is not None:
1709+
request_payload["litellm_logging_obj"] = logging_obj
17061710

17071711
# Log the connection failure using the same pattern as HTTP
17081712
await proxy_logging_obj.post_call_failure_hook(
@@ -1729,6 +1733,8 @@ def __init__(self, target_url: str):
17291733
if kwargs:
17301734
for key, value in kwargs.items():
17311735
request_payload[key] = value
1736+
if logging_obj is not None:
1737+
request_payload["litellm_logging_obj"] = logging_obj
17321738

17331739
# Log the unexpected error using the same pattern as HTTP
17341740
await proxy_logging_obj.post_call_failure_hook(

litellm/proxy/utils.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1729,6 +1729,9 @@ async def post_call_failure_hook(
17291729
original_exception=original_exception,
17301730
)
17311731

1732+
# Remove before callbacks iterate — not serialisable
1733+
request_data.pop("litellm_logging_obj", None)
1734+
17321735
# Track the first HTTPException returned or raised by any callback
17331736
transformed_exception: Optional[HTTPException] = None
17341737

@@ -1849,7 +1852,7 @@ async def _handle_logging_proxy_only_error(
18491852
for k, v in request_data.items():
18501853
if k in litellm_param_keys:
18511854
_litellm_params[k] = v
1852-
elif k != "model" and k != "user":
1855+
elif k not in ("model", "user", "litellm_logging_obj"):
18531856
_optional_params[k] = v
18541857

18551858
litellm_logging_obj.update_environment_variables(
@@ -1865,15 +1868,23 @@ async def _handle_logging_proxy_only_error(
18651868
):
18661869
input = request_data["messages"]
18671870
litellm_logging_obj.model_call_details["messages"] = input
1868-
litellm_logging_obj.call_type = CallTypes.acompletion.value
1871+
if litellm_logging_obj.call_type != CallTypes.pass_through.value:
1872+
litellm_logging_obj.call_type = CallTypes.acompletion.value
18691873
elif "prompt" in request_data and isinstance(request_data["prompt"], str):
18701874
input = request_data["prompt"]
18711875
litellm_logging_obj.model_call_details["prompt"] = input
1872-
litellm_logging_obj.call_type = CallTypes.atext_completion.value
1876+
if litellm_logging_obj.call_type != CallTypes.pass_through.value:
1877+
litellm_logging_obj.call_type = CallTypes.atext_completion.value
18731878
elif "input" in request_data and isinstance(request_data["input"], list):
18741879
input = request_data["input"]
18751880
litellm_logging_obj.model_call_details["input"] = input
1876-
litellm_logging_obj.call_type = CallTypes.aembedding.value
1881+
if litellm_logging_obj.call_type != CallTypes.pass_through.value:
1882+
litellm_logging_obj.call_type = CallTypes.aembedding.value
1883+
# Pass-through endpoints are logged via the callback loop's
1884+
# async_post_call_failure_hook — skip pre_call and failure handlers.
1885+
if litellm_logging_obj.call_type == CallTypes.pass_through.value:
1886+
return
1887+
18771888
litellm_logging_obj.pre_call(
18781889
input=input,
18791890
api_key="",

tests/proxy_unit_tests/test_proxy_utils.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2385,3 +2385,134 @@ async def async_moderation_hook(self, data, user_api_key_dict, call_type):
23852385
assert "Guardrail violation detected!" in str(exc_info.value)
23862386
finally:
23872387
litellm.callbacks = original_callbacks
2388+
2389+
2390+
@pytest.mark.asyncio
async def test_handle_logging_proxy_only_error_preserves_pass_through_call_type():
    """_handle_logging_proxy_only_error must leave a pass-through call_type alone.

    The logging object is created as ``pass_through_endpoint`` and the request
    carries ``messages``; after the error handler runs, the call_type must
    still equal ``CallTypes.pass_through.value``.
    """
    from litellm.caching.caching import DualCache
    from litellm.litellm_core_utils.litellm_logging import Logging
    from litellm.proxy.utils import ProxyLogging
    from litellm.types.utils import CallTypes

    pass_through_logging = Logging(
        model="unknown",
        messages=[{"role": "user", "content": "test"}],
        stream=False,
        call_type="pass_through_endpoint",
        start_time=datetime.now(),
        litellm_call_id="test-call-id",
        function_id="test-function-id",
    )
    failing_request = {
        "litellm_logging_obj": pass_through_logging,
        "messages": [{"role": "user", "content": "test"}],
        "model": "claude-3-5-sonnet",
    }
    proxy_logger = ProxyLogging(user_api_key_cache=DualCache())
    caller_auth = UserAPIKeyAuth(api_key="test_key", token="test_token")

    # Both failure handlers are stubbed so nothing reaches real callbacks.
    with patch.object(
        pass_through_logging, "async_failure_handler", new_callable=AsyncMock
    ), patch.object(pass_through_logging, "failure_handler"):
        await proxy_logger._handle_logging_proxy_only_error(
            request_data=failing_request,
            user_api_key_dict=caller_auth,
            original_exception=Exception("test error"),
        )

    assert pass_through_logging.call_type == CallTypes.pass_through.value
2431+
2432+
@pytest.mark.asyncio
async def test_litellm_logging_obj_excluded_from_optional_params():
    """The logging object itself must never be copied into model_call_details.

    ``request_data`` carries the logging object under ``litellm_logging_obj``;
    if that key were forwarded as an optional param it would create a circular
    reference inside ``model_call_details``.
    """
    from litellm.caching.caching import DualCache
    from litellm.litellm_core_utils.litellm_logging import Logging
    from litellm.proxy.utils import ProxyLogging

    pass_through_logging = Logging(
        model="unknown",
        messages=[{"role": "user", "content": "test"}],
        stream=False,
        call_type="pass_through_endpoint",
        start_time=datetime.now(),
        litellm_call_id="test-call-id",
        function_id="test-function-id",
    )
    failing_request = {
        "litellm_logging_obj": pass_through_logging,
        "messages": [{"role": "user", "content": "test"}],
        "model": "claude-3-5-sonnet",
    }
    proxy_logger = ProxyLogging(user_api_key_cache=DualCache())
    caller_auth = UserAPIKeyAuth(api_key="test_key", token="test_token")

    # Both failure handlers are stubbed so nothing reaches real callbacks.
    with patch.object(
        pass_through_logging, "async_failure_handler", new_callable=AsyncMock
    ), patch.object(pass_through_logging, "failure_handler"):
        await proxy_logger._handle_logging_proxy_only_error(
            request_data=failing_request,
            user_api_key_dict=caller_auth,
            original_exception=Exception("test error"),
        )

    assert "litellm_logging_obj" not in pass_through_logging.model_call_details
2471+
2472+
2473+
@pytest.mark.asyncio
async def test_handle_logging_proxy_only_error_skips_handlers_for_pass_through():
    """Pass-through errors must not reach async_failure_handler or failure_handler.

    Only async_post_call_failure_hook is expected to fire for pass-through
    endpoint errors, which avoids duplicate logs.

    Regression test for duplicate Datadog/Arize logs on pass-through failures.
    """
    from litellm.caching.caching import DualCache
    from litellm.litellm_core_utils.litellm_logging import Logging
    from litellm.proxy.utils import ProxyLogging
    from litellm.types.utils import CallTypes

    pass_through_logging = Logging(
        model="unknown",
        messages=[{"role": "user", "content": "test"}],
        stream=False,
        call_type="pass_through_endpoint",
        start_time=datetime.now(),
        litellm_call_id="test-call-id",
        function_id="test-function-id",
    )
    proxy_logger = ProxyLogging(user_api_key_cache=DualCache())
    failing_request = {
        "litellm_logging_obj": pass_through_logging,
        "messages": [{"role": "user", "content": "test"}],
        "model": "claude-3-5-sonnet",
    }
    caller_auth = UserAPIKeyAuth(api_key="test_key", token="test_token")

    with patch.object(
        pass_through_logging, "async_failure_handler", new_callable=AsyncMock
    ) as async_handler, patch.object(
        pass_through_logging, "failure_handler"
    ) as sync_handler:
        await proxy_logger._handle_logging_proxy_only_error(
            request_data=failing_request,
            user_api_key_dict=caller_auth,
            original_exception=Exception("test error"),
        )

    # Neither handler should fire for pass-through requests
    async_handler.assert_not_called()
    sync_handler.assert_not_called()
    assert pass_through_logging.call_type == CallTypes.pass_through.value

tests/test_litellm/litellm_core_utils/test_litellm_logging.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1969,3 +1969,68 @@ def test_function_setup_empty_metadata_falls_back_to_litellm_metadata():
19691969
assert metadata is not None
19701970
assert metadata.get("user_api_key_hash") == "sk-hashed-empty-test"
19711971
assert metadata.get("user_api_key_team_id") == "team-empty-test"
1972+
1973+
1974+
def test_failure_handler_skips_sync_callbacks_for_pass_through_requests(logging_obj):
    """The sync failure_handler must not call log_failure_event for pass-through requests.

    Regression test for duplicate Datadog/Arize logs on pass-through endpoint
    failures: async_failure_handler fires async_log_failure_event, so the sync
    failure_handler must NOT also fire log_failure_event for this call type.
    """
    from litellm.integrations.custom_logger import CustomLogger
    from litellm.types.utils import CallTypes

    class DummyLogger(CustomLogger):
        """CustomLogger subclass with no overrides; used only as a spy target."""

    logging_obj.call_type = CallTypes.pass_through.value
    logging_obj.stream = False
    logging_obj.model_call_details["litellm_params"] = {}
    logging_obj.litellm_params = {}

    spy_logger = DummyLogger()
    spy_logger.log_failure_event = MagicMock()

    # Force the handler to iterate over exactly one callback: the spy.
    callback_list_patch = patch.object(
        logging_obj,
        "get_combined_callback_list",
        return_value=[spy_logger],
    )
    with callback_list_patch:
        logging_obj.failure_handler(
            exception=Exception("test error"),
            traceback_exception="",
        )

    spy_logger.log_failure_event.assert_not_called()
2006+
2007+
2008+
@pytest.mark.parametrize("call_type", ["completion", "acompletion"])
def test_failure_handler_runs_sync_callbacks_for_non_pass_through_requests(
    logging_obj, call_type
):
    """The sync failure_handler still calls log_failure_event for normal call types.

    Parametrized over ``completion`` and ``acompletion`` so that skipping
    pass-through requests does not silence regular (non-pass-through) requests.
    """
    from litellm.integrations.custom_logger import CustomLogger

    class DummyLogger(CustomLogger):
        """CustomLogger subclass with no overrides; used only as a spy target."""

    logging_obj.call_type = call_type
    logging_obj.stream = False
    logging_obj.model_call_details["litellm_params"] = {}
    logging_obj.litellm_params = {}

    spy_logger = DummyLogger()
    spy_logger.log_failure_event = MagicMock()

    # Force the handler to iterate over exactly one callback: the spy.
    callback_list_patch = patch.object(
        logging_obj,
        "get_combined_callback_list",
        return_value=[spy_logger],
    )
    with callback_list_patch:
        logging_obj.failure_handler(
            exception=Exception("test error"),
            traceback_exception="",
        )

    spy_logger.log_failure_event.assert_called_once()

0 commit comments

Comments
 (0)