diff --git a/tensorrt_llm/_torch/disaggregation/transceiver.py b/tensorrt_llm/_torch/disaggregation/transceiver.py index 73c75dd82759..32d21ced085f 100644 --- a/tensorrt_llm/_torch/disaggregation/transceiver.py +++ b/tensorrt_llm/_torch/disaggregation/transceiver.py @@ -310,8 +310,37 @@ def respond_and_send_async(self, req: LlmRequest): ) self._send_reqs[rid] = req + @nvtx_range("KvCacheTransceiverV2.request_and_receive_sync") def request_and_receive_sync(self, req: LlmRequest): - raise NotImplementedError("request_and_receive_sync is not implemented") + rid = get_unique_rid(req) + if rid in self._recv_sessions: + logger.warning( + f"request_and_receive_sync: rid={rid} already has a recv session, skipping" + ) + return + req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS + session = None + try: + session = self._transfer_worker.create_rx_session(req) + self._recv_sessions[rid] = session + self._recv_reqs[rid] = req + session.receive(self._create_kv_slice(req)) + result = session.wait_complete(blocking=True) + + if result == WaitResult.COMPLETED: + if self._need_aux_transfer(req): + self._apply_aux(session, req) + req.state = LlmRequestState.DISAGG_GENERATION_TRANS_COMPLETE + else: + req.state = LlmRequestState.DISAGG_TRANS_ERROR + except Exception: + req.state = LlmRequestState.DISAGG_TRANS_ERROR + raise + finally: + if session is not None: + session.close() + self._recv_sessions.pop(rid, None) + self._recv_reqs.pop(rid, None) @nvtx_range("KvCacheTransceiverV2.request_and_receive_async") def request_and_receive_async(self, req: LlmRequest): diff --git a/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler_v2.py b/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler_v2.py index 403138deed4b..7e55b33d341f 100644 --- a/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler_v2.py +++ b/tensorrt_llm/_torch/pyexecutor/scheduler/scheduler_v2.py @@ -230,8 +230,19 @@ def _schedule_loop(self, active_requests, inflight_request_ids): req_state_value = req.state_value - # Disagg gen init bypasses normal state gating (same as C++ / V1 scheduler) + # Disagg gen init bypasses normal state gating (same as C++ / V1 scheduler), + # but the V2 scheduler owns inline KV allocation so we must allocate here. + # V1 defers allocation to prepare_resources; V2 prepare_resources is a no-op + # for the primary manager, so allocation must happen in the scheduling loop. if req_state_value == self._disagg_gen_init_state_value: + if not self.kv_cache_manager.prepare_context(req): + req_it += 1 + continue + if not self.kv_cache_manager.resize_context( + req, req.context_remaining_length + get_draft_token_length(req) + ): + req_it += 1 + continue disagg_candidates.append(req) req_it += 1 continue diff --git a/tests/integration/defs/accuracy/test_disaggregated_serving.py b/tests/integration/defs/accuracy/test_disaggregated_serving.py index 2d148a9d866b..2e12cccd3a27 100644 --- a/tests/integration/defs/accuracy/test_disaggregated_serving.py +++ b/tests/integration/defs/accuracy/test_disaggregated_serving.py @@ -633,6 +633,50 @@ def test_auto_dtype(self, ctx_disable_overlap_scheduler, self.MODEL_PATH) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"]) + @skip_pre_hopper + @pytest.mark.skip_less_device(2) + def test_kv_cache_v2_nixl_python(self): + """Test with use_kv_cache_manager_v2=True, block_reuse=False, backend=NIXL, transceiver_runtime=PYTHON.""" + ctx_server_config = { + "disable_overlap_scheduler": True, + "kv_cache_config": { + "enable_block_reuse": False, + "use_kv_cache_manager_v2": True + }, + "cache_transceiver_config": { + "backend": "NIXL", + "transceiver_runtime": "PYTHON" + } + } + gen_server_config = { + "disable_overlap_scheduler": False, + "kv_cache_config": { + "enable_block_reuse": False, + "use_kv_cache_manager_v2": True + }, + "cache_transceiver_config": { + "backend": "NIXL", + "transceiver_runtime": "PYTHON" + } + } + disaggregated_server_config = { + "hostname": "localhost", + "port": 8000, + "backend": "pytorch", + "context_servers": { + "num_instances": 1, + "urls": ["localhost:8001"] + }, + "generation_servers": { + "num_instances": 1, + "urls": ["localhost:8002"] + } + } + with launch_disaggregated_llm(disaggregated_server_config, + ctx_server_config, gen_server_config, + self.MODEL_PATH) as llm: + run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) + @pytest.mark.skip_less_device(2) def test_ngram(self): speculative_decoding_config = { @@ -952,6 +996,59 @@ def test_nixl_backend(self): self.MODEL_PATH) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"]) + @pytest.mark.skip_less_device(2) + @pytest.mark.skip_less_device_memory(60000) + @skip_no_hopper + def test_gen_only_sync(self): + """Test gen-only synchronous KV transfer path with NIXL Python transceiver. + + Sets TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1 so the gen worker calls + request_and_receive_sync instead of the async path, mirroring the + gen-only benchmark mode used for disagg serving performance measurement. + TLLM_BENCHMARK_REQ_QUEUES_SIZE pre-saturates the gen queue with N requests + before the first forward pass (one-time warmup), then processing continues + normally. Accuracy must be identical to the standard async path. + """ + ctx_server_config = { + "disable_overlap_scheduler": True, + "cache_transceiver_config": { + "backend": "NIXL", + "transceiver_runtime": "PYTHON", + "max_tokens_in_buffer": 4096, + }, + } + gen_server_config = { + "disable_overlap_scheduler": True, + "cache_transceiver_config": { + "backend": "NIXL", + "transceiver_runtime": "PYTHON", + "max_tokens_in_buffer": 4096, + }, + } + disaggregated_server_config = { + "hostname": "localhost", + "backend": "pytorch", + "context_servers": { + "num_instances": 1 + }, + "generation_servers": { + "num_instances": 1 + }, + } + extra_env = { + # Use synchronous receive: request_and_receive_sync instead of async. + "TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP": "1", + # Pre-saturate the gen queue with 4 requests before the first + # forward pass (matches gen-only benchmark setup). + "TLLM_BENCHMARK_REQ_QUEUES_SIZE": "4", + } + with launch_disaggregated_llm(disaggregated_server_config, + ctx_server_config, + gen_server_config, + self.MODEL_PATH, + extra_env=extra_env) as llm: + run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) + @pytest.mark.skip_less_device(8) @parametrize_with_ids("overlap_scheduler", [True, False]) @parametrize_with_ids("mtp_nextn", [0, 2]) @@ -1141,6 +1238,51 @@ def test_guided_decoding(self, backend: str, mtp_nextn: int, mocker): self.MODEL_PATH) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["JsonModeEval"]) + @pytest.mark.skip_less_device(2) + @pytest.mark.skip_less_device_memory(60000) + @skip_pre_hopper + def test_kv_cache_v2_nixl_python(self): + """Test with use_kv_cache_manager_v2=True, block_reuse=False, backend=NIXL, transceiver_runtime=PYTHON.""" + ctx_server_config = { + "disable_overlap_scheduler": True, + "kv_cache_config": { + "enable_block_reuse": False, + "use_kv_cache_manager_v2": True + }, + "cache_transceiver_config": { + "backend": "NIXL", + "transceiver_runtime": "PYTHON" + } + } + gen_server_config = { + "disable_overlap_scheduler": True, + "kv_cache_config": { + "enable_block_reuse": False, + "use_kv_cache_manager_v2": True + }, + "cache_transceiver_config": { + "backend": "NIXL", + "transceiver_runtime": "PYTHON" + } + } + disaggregated_server_config = { + "hostname": "localhost", + "port": 8000, + "backend": "pytorch", + "context_servers": { + "num_instances": 1, + "urls": ["localhost:8001"] + }, + "generation_servers": { + "num_instances": 1, + "urls": ["localhost:8002"] + } + } + with launch_disaggregated_llm(disaggregated_server_config, + ctx_server_config, gen_server_config, + self.MODEL_PATH) as llm: + run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"]) + @pytest.mark.timeout(DEFAULT_TEST_TIMEOUT) class TestGemma3_1BInstruct(LlmapiAccuracyTestHarness): @@ -1193,6 +1335,52 @@ def test_auto_dtype(self, block_reuse): self.MODEL_PATH) as llm: run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"]) + @pytest.mark.skip_less_device(2) + @skip_pre_hopper + def test_kv_cache_v2_nixl_python(self): + """Test with use_kv_cache_manager_v2=True, block_reuse=False, backend=NIXL, transceiver_runtime=PYTHON.""" + ctx_server_config = { + "disable_overlap_scheduler": True, + "cuda_graph_config": None, + "kv_cache_config": { + "enable_block_reuse": False, + "use_kv_cache_manager_v2": True + }, + "cache_transceiver_config": { + "backend": "NIXL", + "transceiver_runtime": "PYTHON" + } + } + gen_server_config = { + "disable_overlap_scheduler": True, + "cuda_graph_config": None, + "kv_cache_config": { + "enable_block_reuse": False, + "use_kv_cache_manager_v2": True + }, + "cache_transceiver_config": { + "backend": "NIXL", + "transceiver_runtime": "PYTHON" + } + } + disaggregated_server_config = { + "hostname": "localhost", + "port": 8000, + "backend": "pytorch", + "context_servers": { + "num_instances": 1, + "urls": ["localhost:8001"] + }, + "generation_servers": { + "num_instances": 1, + "urls": ["localhost:8002"] + } + } + with launch_disaggregated_llm(disaggregated_server_config, + ctx_server_config, gen_server_config, + self.MODEL_PATH) as llm: + run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"]) + @skip_pre_blackwell @pytest.mark.skip_less_device_memory(80000) diff --git a/tests/integration/test_lists/qa/llm_function_core.txt b/tests/integration/test_lists/qa/llm_function_core.txt index a49644bbb479..e936d9423495 100644 --- a/tests/integration/test_lists/qa/llm_function_core.txt +++ b/tests/integration/test_lists/qa/llm_function_core.txt @@ -389,14 +389,18 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp1dp2cp2] accuracy/test_disaggregated_serving.py::TestDeepSeekV32Exp::test_auto_dtype_with_helix[fifo-cudagraph:with_padding-pp2tp1cp2] accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_gen_only_sync +accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_kv_cache_v2_nixl_python accuracy/test_dwdp_disaggregated_serving.py::TestDwdpDeepSeekV3Lite::test_dwdp_accuracy accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] +accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_kv_cache_v2_nixl_python accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True] accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[False] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[True-True-True-True] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram +accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_kv_cache_v2_nixl_python accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=True-overlap_scheduler=True] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_eagle3[eagle3_one_model=False-overlap_scheduler=False] accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_guided_decoding[xgrammar] diff --git a/tests/integration/test_lists/test-db/l0_dgx_b300.yml b/tests/integration/test_lists/test-db/l0_dgx_b300.yml index 1c96f99068cb..c7db16511701 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_b300.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_b300.yml @@ -92,6 +92,9 @@ l0_dgx_b300: - disaggregated/test_disaggregated.py::test_disaggregated_benchmark_on_diff_backends[DeepSeek-V3-Lite-fp8] - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_kv_cache_v2_nixl_python + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_kv_cache_v2_nixl_python + - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_kv_cache_v2_nixl_python - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_tp4] TIMEOUT (180) - accuracy/test_llm_api_pytorch.py::TestDeepSeekR1::test_nvfp4_multi_gpus[throughput_pp4_mtp] TIMEOUT (180) - condition: diff --git a/tests/integration/test_lists/test-db/l0_dgx_h100.yml b/tests/integration/test_lists/test-db/l0_dgx_h100.yml index ea58ff740384..6016b5bd6612 100644 --- a/tests/integration/test_lists/test-db/l0_dgx_h100.yml +++ b/tests/integration/test_lists/test-db/l0_dgx_h100.yml @@ -31,9 +31,13 @@ l0_dgx_h100: - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_chunked_prefill - accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_gen_only_sync + - accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_kv_cache_v2_nixl_python - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram + - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_kv_cache_v2_nixl_python - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False] - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True] + - accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_kv_cache_v2_nixl_python - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-False] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-False-True] - accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_auto_dtype[False-False-True-False] diff --git a/tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py b/tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py index a2488b8a1474..2837f91dee31 100644 --- a/tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py +++ b/tests/unittest/_torch/executor/test_kv_cache_v2_scheduler.py @@ -99,7 +99,7 @@ def make_encoder_request(request_id, encoder_output_len, lora_task_id=None): return req -def make_disagg_request(request_id): +def make_disagg_request(request_id, context_remaining_length=1, num_draft_tokens=0): req = Mock() req.request_id = request_id req.py_request_id = request_id @@ -107,6 +107,10 @@ def make_disagg_request(request_id): req.is_context_init_state = False req.is_generation_in_progress_state = False req.is_first_context_chunk = True + req.context_remaining_length = context_remaining_length + req.num_draft_tokens = num_draft_tokens + req.has_draft_tokens = num_draft_tokens > 0 + req.py_draft_tokens = [0] * num_draft_tokens if num_draft_tokens > 0 else [] return req diff --git a/tests/unittest/disaggregated/test_py_cache_transceiver_mp.py b/tests/unittest/disaggregated/test_py_cache_transceiver_mp.py index 730c5589099e..68081984d277 100644 --- a/tests/unittest/disaggregated/test_py_cache_transceiver_mp.py +++ b/tests/unittest/disaggregated/test_py_cache_transceiver_mp.py @@ -700,6 +700,10 @@ def gather_and_verify_request( _run_gen_first1_transfer(rank, is_ctx, transceiver, my_requests) elif ctx_gen_workflow == "gen_first2": _run_gen_first2_transfer(rank, is_ctx, transceiver, my_requests) + elif ctx_gen_workflow == "ctx_first_sync": + _run_ctx_first_sync_transfer( + rank, is_ctx, transceiver, my_requests, ctx_enable_dp, gen_enable_dp + ) else: _run_ctx_first_transfer( rank, is_ctx, transceiver, my_requests, ctx_enable_dp, gen_enable_dp @@ -886,6 +890,55 @@ def _wait_ctx_request_ready(transceiver, my_requests): return all_ready +def _run_ctx_first_sync_transfer( + rank, is_ctx, transceiver, my_requests, ctx_enable_dp, gen_enable_dp +): + """Context-first transfer using synchronous receive (request_and_receive_sync).""" + do_warmup = not ctx_enable_dp and not gen_enable_dp and len(my_requests) > 0 + if do_warmup: + warmup_idx, warmup_request = my_requests[0] + remaining_requests = my_requests[1:] + + if is_ctx: + print(f"[Rank {rank}] CTX: Submitting warmup request {warmup_idx}...", flush=True) + transceiver.respond_and_send_async(warmup_request) + + print(f"[Rank {rank}] Before warmup barrier", flush=True) + dist.barrier() + print(f"[Rank {rank}] After warmup barrier", flush=True) + + if not is_ctx: + print(f"[Rank {rank}] GEN: Sync-receiving warmup request {warmup_idx}...", flush=True) + transceiver.request_and_receive_sync(warmup_request) + print(f"[Rank {rank}] GEN: Warmup completed (sync)", flush=True) + + if is_ctx: + transceiver.check_context_transfer_status(None) + print(f"[Rank {rank}] CTX: Warmup completed", flush=True) + + print(f"[Rank {rank}] Before post-warmup barrier", flush=True) + dist.barrier() + print(f"[Rank {rank}] After post-warmup barrier", flush=True) + else: + remaining_requests = my_requests + + if is_ctx: + for req_idx, request in remaining_requests: + print(f"[Rank {rank}] CTX: Submitting request {req_idx}...", flush=True) + transceiver.respond_and_send_async(request) + print(f"[Rank {rank}] CTX: Submitted {len(remaining_requests)} send requests", flush=True) + + print(f"[Rank {rank}] Before phase2 barrier", flush=True) + dist.barrier() + print(f"[Rank {rank}] After phase2 barrier", flush=True) + + if not is_ctx: + for req_idx, request in remaining_requests: + print(f"[Rank {rank}] GEN: Sync-receiving request {req_idx}...", flush=True) + transceiver.request_and_receive_sync(request) + print(f"[Rank {rank}] GEN: Sync-received {len(remaining_requests)} requests", flush=True) + + def _run_gen_first1_transfer(rank, is_ctx, transceiver, my_requests): """Generation-first transfer: ctx prepares first, then gen receives and ctx sends.""" # Step 1: Context side calls prepare_context_requests, no kvcache request is sent, thus no request @@ -1073,7 +1126,10 @@ def run_v2_transceiver_mp( [(c[0], c[1], c[2], c[3], c[4], c[5], c[6]) for c in MP_TEST_CONFIGS], ids=[c[7] for c in MP_TEST_CONFIGS], ) -@pytest.mark.parametrize("workflow", ["ctx_first", "gen_first1", "gen_first2"]) +@pytest.mark.parametrize( + "workflow", + ["ctx_first", "ctx_first_sync", "gen_first1", "gen_first2"], +) def test_v2_transceiver_mp( ctx_tp, ctx_pp, gen_tp, gen_pp, ctx_enable_dp, gen_enable_dp, is_mla, workflow ):