Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 30 additions & 1 deletion tensorrt_llm/_torch/disaggregation/transceiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,37 @@ def respond_and_send_async(self, req: LlmRequest):
)
self._send_reqs[rid] = req

@nvtx_range("KvCacheTransceiverV2.request_and_receive_sync")
def request_and_receive_sync(self, req: LlmRequest):
raise NotImplementedError("request_and_receive_sync is not implemented")
rid = get_unique_rid(req)
if rid in self._recv_sessions:
logger.warning(
f"request_and_receive_sync: rid={rid} already has a recv session, skipping"
)
return
req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS
session = None
try:
session = self._transfer_worker.create_rx_session(req)
self._recv_sessions[rid] = session
self._recv_reqs[rid] = req
session.receive(self._create_kv_slice(req))
result = session.wait_complete(blocking=True)

if result == WaitResult.COMPLETED:
if self._need_aux_transfer(req):
self._apply_aux(session, req)
req.state = LlmRequestState.DISAGG_GENERATION_TRANS_COMPLETE
else:
req.state = LlmRequestState.DISAGG_TRANS_ERROR
except Exception:
req.state = LlmRequestState.DISAGG_TRANS_ERROR
raise
finally:
if session is not None:
session.close()
self._recv_sessions.pop(rid, None)
self._recv_reqs.pop(rid, None)

@nvtx_range("KvCacheTransceiverV2.request_and_receive_async")
def request_and_receive_async(self, req: LlmRequest):
Expand Down
13 changes: 12 additions & 1 deletion tensorrt_llm/_torch/pyexecutor/scheduler/scheduler_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,8 +230,19 @@ def _schedule_loop(self, active_requests, inflight_request_ids):

req_state_value = req.state_value

# Disagg gen init bypasses normal state gating (same as C++ / V1 scheduler)
# Disagg gen init bypasses normal state gating (same as C++ / V1 scheduler),
# but the V2 scheduler owns inline KV allocation so we must allocate here.
# V1 defers allocation to prepare_resources; V2 prepare_resources is a no-op
# for the primary manager, so allocation must happen in the scheduling loop.
if req_state_value == self._disagg_gen_init_state_value:
if not self.kv_cache_manager.prepare_context(req):
req_it += 1
continue
if not self.kv_cache_manager.resize_context(
req, req.context_remaining_length + get_draft_token_length(req)
):
req_it += 1
continue
disagg_candidates.append(req)
req_it += 1
continue
Expand Down
188 changes: 188 additions & 0 deletions tests/integration/defs/accuracy/test_disaggregated_serving.py
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,50 @@ def test_auto_dtype(self, ctx_disable_overlap_scheduler,
self.MODEL_PATH) as llm:
run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])

@skip_pre_hopper
@pytest.mark.skip_less_device(2)
def test_kv_cache_v2_nixl_python(self):
"""Test with use_kv_cache_manager_v2=True, block_reuse=False, backend=NIXL, transceiver_runtime=PYTHON."""
ctx_server_config = {
"disable_overlap_scheduler": True,
"kv_cache_config": {
"enable_block_reuse": False,
"use_kv_cache_manager_v2": True
},
"cache_transceiver_config": {
"backend": "NIXL",
"transceiver_runtime": "PYTHON"
}
}
gen_server_config = {
"disable_overlap_scheduler": False,
"kv_cache_config": {
"enable_block_reuse": False,
"use_kv_cache_manager_v2": True
},
"cache_transceiver_config": {
"backend": "NIXL",
"transceiver_runtime": "PYTHON"
}
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config, gen_server_config,
self.MODEL_PATH) as llm:
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])

@pytest.mark.skip_less_device(2)
def test_ngram(self):
speculative_decoding_config = {
Expand Down Expand Up @@ -952,6 +996,59 @@ def test_nixl_backend(self):
self.MODEL_PATH) as llm:
run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])

@pytest.mark.skip_less_device(2)
@pytest.mark.skip_less_device_memory(60000)
@skip_no_hopper
def test_gen_only_sync(self):
"""Test gen-only synchronous KV transfer path with NIXL Python transceiver.

Sets TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1 so the gen worker calls
request_and_receive_sync instead of the async path, mirroring the
gen-only benchmark mode used for disagg serving performance measurement.
TLLM_BENCHMARK_REQ_QUEUES_SIZE pre-saturates the gen queue with N requests
before the first forward pass (one-time warmup), then processing continues
normally. Accuracy must be identical to the standard async path.
"""
ctx_server_config = {
"disable_overlap_scheduler": True,
"cache_transceiver_config": {
"backend": "NIXL",
"transceiver_runtime": "PYTHON",
"max_tokens_in_buffer": 4096,
},
}
gen_server_config = {
"disable_overlap_scheduler": True,
"cache_transceiver_config": {
"backend": "NIXL",
"transceiver_runtime": "PYTHON",
"max_tokens_in_buffer": 4096,
},
}
disaggregated_server_config = {
"hostname": "localhost",
"backend": "pytorch",
"context_servers": {
"num_instances": 1
},
"generation_servers": {
"num_instances": 1
},
}
extra_env = {
# Use synchronous receive: request_and_receive_sync instead of async.
"TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP": "1",
# Pre-saturate the gen queue with 4 requests before the first
# forward pass (matches gen-only benchmark setup).
"TLLM_BENCHMARK_REQ_QUEUES_SIZE": "4",
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config,
gen_server_config,
self.MODEL_PATH,
extra_env=extra_env) as llm:
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])

@pytest.mark.skip_less_device(8)
@parametrize_with_ids("overlap_scheduler", [True, False])
@parametrize_with_ids("mtp_nextn", [0, 2])
Expand Down Expand Up @@ -1141,6 +1238,51 @@ def test_guided_decoding(self, backend: str, mtp_nextn: int, mocker):
self.MODEL_PATH) as llm:
run_accuracy_test(llm, self.MODEL_NAME, ["JsonModeEval"])

@pytest.mark.skip_less_device(2)
@pytest.mark.skip_less_device_memory(60000)
@skip_pre_hopper
def test_kv_cache_v2_nixl_python(self):
"""Test with use_kv_cache_manager_v2=True, block_reuse=False, backend=NIXL, transceiver_runtime=PYTHON."""
ctx_server_config = {
"disable_overlap_scheduler": True,
"kv_cache_config": {
"enable_block_reuse": False,
"use_kv_cache_manager_v2": True
},
"cache_transceiver_config": {
"backend": "NIXL",
"transceiver_runtime": "PYTHON"
}
}
gen_server_config = {
"disable_overlap_scheduler": True,
"kv_cache_config": {
"enable_block_reuse": False,
"use_kv_cache_manager_v2": True
},
"cache_transceiver_config": {
"backend": "NIXL",
"transceiver_runtime": "PYTHON"
}
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config, gen_server_config,
self.MODEL_PATH) as llm:
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])


@pytest.mark.timeout(DEFAULT_TEST_TIMEOUT)
class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness):
Expand Down Expand Up @@ -1193,6 +1335,52 @@ def test_auto_dtype(self, block_reuse):
self.MODEL_PATH) as llm:
run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])

@pytest.mark.skip_less_device(2)
@skip_pre_hopper
def test_kv_cache_v2_nixl_python(self):
"""Test with use_kv_cache_manager_v2=True, block_reuse=False, backend=NIXL, transceiver_runtime=PYTHON."""
ctx_server_config = {
"disable_overlap_scheduler": True,
"cuda_graph_config": None,
"kv_cache_config": {
"enable_block_reuse": False,
"use_kv_cache_manager_v2": True
},
"cache_transceiver_config": {
"backend": "NIXL",
"transceiver_runtime": "PYTHON"
}
}
gen_server_config = {
"disable_overlap_scheduler": True,
"cuda_graph_config": None,
"kv_cache_config": {
"enable_block_reuse": False,
"use_kv_cache_manager_v2": True
},
"cache_transceiver_config": {
"backend": "NIXL",
"transceiver_runtime": "PYTHON"
}
}
disaggregated_server_config = {
"hostname": "localhost",
"port": 8000,
"backend": "pytorch",
"context_servers": {
"num_instances": 1,
"urls": ["localhost:8001"]
},
"generation_servers": {
"num_instances": 1,
"urls": ["localhost:8002"]
}
}
with launch_disaggregated_llm(disaggregated_server_config,
ctx_server_config, gen_server_config,
self.MODEL_PATH) as llm:
run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])


@skip_pre_blackwell
@pytest.mark.skip_less_device_memory(80000)
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/test_lists/qa/llm_function_core.txt
Original file line number Diff line number Diff line change
Expand Up @@ -389,14 +389,18 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with
accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1dp2cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2]
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_gen_only_sync
accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_kv_cache_v2_nixl_python
accuracy/test_dwdp_disaggregated_serving.py::TestDwdpDeepSeekV3Lite::test_dwdp_accuracy
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_kv_cache_v2_nixl_python
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]
accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_kv_cache_v2_nixl_python
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False]
accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar]
Expand Down
3 changes: 3 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_b300.yml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ l0_dgx_b300:
- disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-fp8]
- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_kv_cache_v2_nixl_python
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_kv_cache_v2_nixl_python
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_kv_cache_v2_nixl_python
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4] TIMEOUT (180)
- accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_pp4_mtp] TIMEOUT (180)
- condition:
Expand Down
4 changes: 4 additions & 0 deletions tests/integration/test_lists/test-db/l0_dgx_h100.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,13 @@ l0_dgx_h100:
- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_chunked_prefill
- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_gen_only_sync
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_kv_cache_v2_nixl_python
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_kv_cache_v2_nixl_python
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_kv_cache_v2_nixl_python
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-False]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True]
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-True-False]
Expand Down
6 changes: 5 additions & 1 deletion tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,18 @@ def make_encoder_request(request_id, encoder_output_len, lora_task_id=None):
return req


def make_disagg_request(request_id, context_remaining_length=1, num_draft_tokens=0):
    """Build a mock request in the disaggregated-generation-init state.

    Mirrors the attributes the V2 scheduler reads when it encounters a
    DISAGG_GEN_INIT request: identity fields, state flags, and the sizing
    inputs used for inline KV allocation (``context_remaining_length`` plus
    the draft-token count).

    Args:
        request_id: Value assigned to both ``request_id`` and ``py_request_id``.
        context_remaining_length: Remaining context tokens to allocate for.
        num_draft_tokens: Number of draft tokens attached to the request.

    Returns:
        A ``Mock`` configured as a DISAGG_GEN_INIT request.
    """
    req = Mock()
    req.request_id = request_id
    req.py_request_id = request_id
    req.state_value = DISAGG_GEN_INIT
    req.is_context_init_state = False
    req.is_generation_in_progress_state = False
    req.is_first_context_chunk = True
    req.context_remaining_length = context_remaining_length
    req.num_draft_tokens = num_draft_tokens
    req.has_draft_tokens = num_draft_tokens > 0
    # [0] * 0 == [], so no conditional is needed for the zero-draft case.
    req.py_draft_tokens = [0] * num_draft_tokens
    return req


Expand Down
Loading
Loading