From 5909f2c2b8f0f24cbaa467123530769070f1e30e Mon Sep 17 00:00:00 2001 From: Pichaya via Codex Date: Thu, 9 Apr 2026 21:15:41 +0700 Subject: [PATCH] Fix missing disagg request id fallback in context responses --- tensorrt_llm/serve/openai_disagg_service.py | 9 +++++-- .../test_openai_disagg_service.py | 27 +++++++++++++++++++ 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/tensorrt_llm/serve/openai_disagg_service.py b/tensorrt_llm/serve/openai_disagg_service.py index 387d7fa03ec..3f528bc448c 100644 --- a/tensorrt_llm/serve/openai_disagg_service.py +++ b/tensorrt_llm/serve/openai_disagg_service.py @@ -362,8 +362,13 @@ async def _verify_ctx_response(self, ctx_response: UCompletionResponse) -> None: if ctx_response.choices[0].disaggregated_params.ctx_request_id is None: raise ValueError("Invalid disaggregated params in context phase response.") if ctx_response.choices[0].disaggregated_params.disagg_request_id is None: - raise ValueError( - "Invalid disaggregated params in context phase response. disagg_request_id is None" + logger.warning( + "Context phase response is missing disagg_request_id; " + "falling back to ctx_request_id=%s", + ctx_response.choices[0].disaggregated_params.ctx_request_id, + ) + ctx_response.choices[0].disaggregated_params.disagg_request_id = ( + ctx_response.choices[0].disaggregated_params.ctx_request_id ) return ctx_response diff --git a/tests/unittest/disaggregated/test_openai_disagg_service.py b/tests/unittest/disaggregated/test_openai_disagg_service.py index bcbef10b9e2..69cb848f49a 100644 --- a/tests/unittest/disaggregated/test_openai_disagg_service.py +++ b/tests/unittest/disaggregated/test_openai_disagg_service.py @@ -159,6 +159,33 @@ async def _delayed_gen_response(*_args, **_kwargs): ) +@pytest.mark.asyncio +async def test_verify_ctx_response_falls_back_to_ctx_request_id(): + service = _make_service("context_first") + response = CompletionResponse( + model="test-model", + usage=UsageInfo(prompt_tokens=1, completion_tokens=1), + prompt_token_ids=[1, 2, 3], + choices=[ + CompletionResponseChoice( + index=0, + text="ctx-only", + finish_reason="length", + disaggregated_params=DisaggregatedParams( + request_type="context_only", + ctx_request_id=1234, + disagg_request_id=None, + ), + ) + ], + ) + + await service._verify_ctx_response(response) + + assert response.choices[0].disaggregated_params.ctx_request_id == 1234 + assert response.choices[0].disaggregated_params.disagg_request_id == 1234 + + class TestFirstGenLogProbsSerializeRoundtrip: """Roundtrip tests for _serialize/_deserialize_first_gen_log_probs."""