Skip to content

Commit a2b012d

Browse files
committed
Gen-only sync KV transfer for dis-agg
Signed-off-by: Shixiaowei02 <39303645+Shixiaowei02@users.noreply.github.com>
1 parent ce71620 commit a2b012d

5 files changed

Lines changed: 223 additions & 2 deletions

File tree

tensorrt_llm/_torch/disaggregation/transceiver.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,31 @@ def respond_and_send_async(self, req: LlmRequest):
310310
)
311311
self._send_reqs[rid] = req
312312

313+
@nvtx_range("KvCacheTransceiverV2.request_and_receive_sync")
313314
def request_and_receive_sync(self, req: LlmRequest):
314-
raise NotImplementedError("request_and_receive_sync is not implemented")
315+
rid = get_unique_rid(req)
316+
if rid in self._recv_sessions:
317+
logger.warning(
318+
f"request_and_receive_sync: rid={rid} already has a recv session, skipping"
319+
)
320+
return
321+
req.state = LlmRequestState.DISAGG_GENERATION_TRANS_IN_PROGRESS
322+
session = self._transfer_worker.create_rx_session(req)
323+
self._recv_sessions[rid] = session
324+
self._recv_reqs[rid] = req
325+
session.receive(self._create_kv_slice(req))
326+
result = session.wait_complete(blocking=True)
327+
328+
if result == WaitResult.COMPLETED:
329+
if self._need_aux_transfer(req):
330+
self._apply_aux(session, req)
331+
req.state = LlmRequestState.DISAGG_GENERATION_TRANS_COMPLETE
332+
else:
333+
req.state = LlmRequestState.DISAGG_TRANS_ERROR
334+
335+
session.close()
336+
del self._recv_sessions[rid]
337+
del self._recv_reqs[rid]
315338

316339
@nvtx_range("KvCacheTransceiverV2.request_and_receive_async")
317340
def request_and_receive_async(self, req: LlmRequest):

tests/integration/defs/accuracy/test_disaggregated_serving.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -956,6 +956,145 @@ def test_nixl_backend(self):
956956
self.MODEL_PATH) as llm:
957957
run_accuracy_test(llm, self.MODEL_NAME, ["MMLU", "GSM8K"])
958958

959+
@pytest.mark.skip_less_device(2)
960+
@pytest.mark.skip_less_device_memory(60000)
961+
@skip_no_hopper
962+
def test_gen_only_sync(self):
963+
"""Test gen-only synchronous KV transfer path with NIXL Python transceiver.
964+
965+
Sets TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP=1 so the gen worker calls
966+
request_and_receive_sync instead of the async path, mirroring the
967+
gen-only benchmark mode used for disagg serving performance measurement.
968+
TLLM_BENCHMARK_REQ_QUEUES_SIZE pre-saturates the gen queue with N requests
969+
before the first forward pass (one-time warmup), then processing continues
970+
normally. Accuracy must be identical to the standard async path.
971+
"""
972+
ctx_server_config = {
973+
"disable_overlap_scheduler": True,
974+
"cache_transceiver_config": {
975+
"backend": "NIXL",
976+
"transceiver_runtime": "PYTHON",
977+
"max_tokens_in_buffer": 4096,
978+
},
979+
}
980+
gen_server_config = {
981+
"disable_overlap_scheduler": True,
982+
"cache_transceiver_config": {
983+
"backend": "NIXL",
984+
"transceiver_runtime": "PYTHON",
985+
"max_tokens_in_buffer": 4096,
986+
},
987+
}
988+
disaggregated_server_config = {
989+
"hostname": "localhost",
990+
"backend": "pytorch",
991+
"context_servers": {
992+
"num_instances": 1
993+
},
994+
"generation_servers": {
995+
"num_instances": 1
996+
},
997+
}
998+
extra_env = {
999+
# Use synchronous receive: request_and_receive_sync instead of async.
1000+
"TRTLLM_DISABLE_KV_CACHE_TRANSFER_OVERLAP": "1",
1001+
# Pre-saturate the gen queue with 4 requests before the first
1002+
# forward pass (matches gen-only benchmark setup).
1003+
"TLLM_BENCHMARK_REQ_QUEUES_SIZE": "4",
1004+
}
1005+
with launch_disaggregated_llm(disaggregated_server_config,
1006+
ctx_server_config,
1007+
gen_server_config,
1008+
self.MODEL_PATH,
1009+
extra_env=extra_env) as llm:
1010+
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
1011+
1012+
@pytest.mark.skip_less_device(2)
1013+
@pytest.mark.skip_less_device_memory(60000)
1014+
@skip_no_hopper
1015+
def test_kv_cache_manager_v2(self):
1016+
"""Test disaggregated serving with KVCacheManagerV2 and NIXL Python transceiver."""
1017+
ctx_server_config = {
1018+
"disable_overlap_scheduler": True,
1019+
"kv_cache_config": {
1020+
"use_kv_cache_manager_v2": True,
1021+
},
1022+
"cache_transceiver_config": {
1023+
"backend": "NIXL",
1024+
"transceiver_runtime": "PYTHON",
1025+
"max_tokens_in_buffer": 4096,
1026+
},
1027+
}
1028+
gen_server_config = {
1029+
"disable_overlap_scheduler": True,
1030+
"kv_cache_config": {
1031+
"use_kv_cache_manager_v2": True,
1032+
},
1033+
"cache_transceiver_config": {
1034+
"backend": "NIXL",
1035+
"transceiver_runtime": "PYTHON",
1036+
"max_tokens_in_buffer": 4096,
1037+
},
1038+
}
1039+
disaggregated_server_config = {
1040+
"hostname": "localhost",
1041+
"backend": "pytorch",
1042+
"context_servers": {
1043+
"num_instances": 1
1044+
},
1045+
"generation_servers": {
1046+
"num_instances": 1
1047+
},
1048+
}
1049+
with launch_disaggregated_llm(disaggregated_server_config,
1050+
ctx_server_config, gen_server_config,
1051+
self.MODEL_PATH) as llm:
1052+
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
1053+
1054+
@pytest.mark.skip_less_device(8)
1055+
@skip_no_hopper
1056+
def test_kv_cache_manager_v2_ctx_tp2pp2_gen_tp4(self):
1057+
"""Test KVCacheManagerV2 with asymmetric ctx/gen topology: ctx=tp2pp2, gen=tp4."""
1058+
ctx_server_config = {
1059+
"disable_overlap_scheduler": True,
1060+
"tensor_parallel_size": 2,
1061+
"pipeline_parallel_size": 2,
1062+
"kv_cache_config": {
1063+
"use_kv_cache_manager_v2": True,
1064+
},
1065+
"cache_transceiver_config": {
1066+
"backend": "NIXL",
1067+
"transceiver_runtime": "PYTHON",
1068+
"max_tokens_in_buffer": 4096,
1069+
},
1070+
}
1071+
gen_server_config = {
1072+
"disable_overlap_scheduler": True,
1073+
"tensor_parallel_size": 4,
1074+
"kv_cache_config": {
1075+
"use_kv_cache_manager_v2": True,
1076+
},
1077+
"cache_transceiver_config": {
1078+
"backend": "NIXL",
1079+
"transceiver_runtime": "PYTHON",
1080+
"max_tokens_in_buffer": 4096,
1081+
},
1082+
}
1083+
disaggregated_server_config = {
1084+
"hostname": "localhost",
1085+
"backend": "pytorch",
1086+
"context_servers": {
1087+
"num_instances": 1
1088+
},
1089+
"generation_servers": {
1090+
"num_instances": 1
1091+
},
1092+
}
1093+
with launch_disaggregated_llm(disaggregated_server_config,
1094+
ctx_server_config, gen_server_config,
1095+
self.MODEL_PATH) as llm:
1096+
run_accuracy_test(llm, self.MODEL_NAME, ["GSM8K"])
1097+
9591098
@pytest.mark.skip_less_device(8)
9601099
@parametrize_with_ids("overlap_scheduler", [True, False])
9611100
@parametrize_with_ids("mtp_nextn", [0, 2])

tests/integration/test_lists/test-db/l0_dgx_h100.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ l0_dgx_h100:
3131
- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_chunked_prefill
3232
- accuracy/test_disaggregated_serving.py::TestQwen3_8B::test_nixl_backend
3333
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_nixl_backend
34+
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_gen_only_sync
35+
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_kv_cache_manager_v2
3436
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ngram
3537
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
3638
- accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]

tests/integration/test_lists/test-db/l0_dgx_h200.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ l0_dgx_h200:
2828
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=0-overlap_scheduler=False]
2929
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=True]
3030
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype[mtp_nextn=2-overlap_scheduler=False]
31+
- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_kv_cache_manager_v2_ctx_tp2pp2_gen_tp4
3132
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[GSM8K-tp2pp2]
3233
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_tp_pp_symmetric[MMLU-tp2pp2]
3334
- accuracy/test_disaggregated_serving.py::TestLlama3_1_8BInstruct::test_ctx_pp_gen_tp_asymmetric[GSM8K-gen_tp=1-ctx_pp=4]

tests/unittest/disaggregated/test_py_cache_transceiver_mp.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,10 @@ def gather_and_verify_request(
700700
_run_gen_first1_transfer(rank, is_ctx, transceiver, my_requests)
701701
elif ctx_gen_workflow == "gen_first2":
702702
_run_gen_first2_transfer(rank, is_ctx, transceiver, my_requests)
703+
elif ctx_gen_workflow == "ctx_first_sync":
704+
_run_ctx_first_sync_transfer(
705+
rank, is_ctx, transceiver, my_requests, ctx_enable_dp, gen_enable_dp
706+
)
703707
else:
704708
_run_ctx_first_transfer(
705709
rank, is_ctx, transceiver, my_requests, ctx_enable_dp, gen_enable_dp
@@ -886,6 +890,55 @@ def _wait_ctx_request_ready(transceiver, my_requests):
886890
return all_ready
887891

888892

893+
def _run_ctx_first_sync_transfer(
    rank, is_ctx, transceiver, my_requests, ctx_enable_dp, gen_enable_dp
):
    """Context-first transfer using synchronous receive (request_and_receive_sync)."""
    # Warmup round only when neither side uses DP and there is work to do.
    do_warmup = not (ctx_enable_dp or gen_enable_dp) and bool(my_requests)

    remaining_requests = my_requests
    if do_warmup:
        warmup_idx, warmup_request = my_requests[0]
        remaining_requests = my_requests[1:]

        # Ctx submits the warmup send before everyone meets at the barrier.
        if is_ctx:
            print(f"[Rank {rank}] CTX: Submitting warmup request {warmup_idx}...", flush=True)
            transceiver.respond_and_send_async(warmup_request)

        print(f"[Rank {rank}] Before warmup barrier", flush=True)
        dist.barrier()
        print(f"[Rank {rank}] After warmup barrier", flush=True)

        # Each rank executes exactly one side of the warmup handshake.
        if is_ctx:
            transceiver.check_context_transfer_status(None)
            print(f"[Rank {rank}] CTX: Warmup completed", flush=True)
        else:
            print(f"[Rank {rank}] GEN: Sync-receiving warmup request {warmup_idx}...", flush=True)
            transceiver.request_and_receive_sync(warmup_request)
            print(f"[Rank {rank}] GEN: Warmup completed (sync)", flush=True)

        print(f"[Rank {rank}] Before post-warmup barrier", flush=True)
        dist.barrier()
        print(f"[Rank {rank}] After post-warmup barrier", flush=True)

    # Phase 2: ctx submits all remaining sends first ...
    if is_ctx:
        for req_idx, request in remaining_requests:
            print(f"[Rank {rank}] CTX: Submitting request {req_idx}...", flush=True)
            transceiver.respond_and_send_async(request)
        print(f"[Rank {rank}] CTX: Submitted {len(remaining_requests)} send requests", flush=True)

    print(f"[Rank {rank}] Before phase2 barrier", flush=True)
    dist.barrier()
    print(f"[Rank {rank}] After phase2 barrier", flush=True)

    # ... then gen drains them one by one with the blocking receive.
    if not is_ctx:
        for req_idx, request in remaining_requests:
            print(f"[Rank {rank}] GEN: Sync-receiving request {req_idx}...", flush=True)
            transceiver.request_and_receive_sync(request)
        print(f"[Rank {rank}] GEN: Sync-received {len(remaining_requests)} requests", flush=True)
940+
941+
889942
def _run_gen_first1_transfer(rank, is_ctx, transceiver, my_requests):
890943
"""Generation-first transfer: ctx prepares first, then gen receives and ctx sends."""
891944
# Step 1: Context side calls prepare_context_requests, no kvcache request is sent, thus no request
@@ -1073,7 +1126,10 @@ def run_v2_transceiver_mp(
10731126
[(c[0], c[1], c[2], c[3], c[4], c[5], c[6]) for c in MP_TEST_CONFIGS],
10741127
ids=[c[7] for c in MP_TEST_CONFIGS],
10751128
)
1076-
@pytest.mark.parametrize("workflow", ["ctx_first", "gen_first1", "gen_first2"])
1129+
@pytest.mark.parametrize(
1130+
"workflow",
1131+
["ctx_first", "ctx_first_sync", "gen_first1", "gen_first2"],
1132+
)
10771133
def test_v2_transceiver_mp(
10781134
ctx_tp, ctx_pp, gen_tp, gen_pp, ctx_enable_dp, gen_enable_dp, is_mla, workflow
10791135
):

0 commit comments

Comments
 (0)