
Commit 0aeb698

[Model Runner V2] Minor code cleanup (vllm-project#29570)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 9bb33c8 commit 0aeb698

3 files changed: 18 additions, 18 deletions

vllm/v1/worker/gpu/cudagraph_utils.py

Lines changed: 2 additions & 9 deletions
@@ -16,6 +16,7 @@
 from vllm.v1.kv_cache_interface import KVCacheConfig
 from vllm.v1.worker.gpu.attn_utils import build_attn_metadata
 from vllm.v1.worker.gpu.block_table import BlockTables
+from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp
 from vllm.v1.worker.gpu.input_batch import InputBuffers
 
 
@@ -127,15 +128,7 @@ def capture_graph(
             slot_mappings=slot_mappings,
             kv_cache_config=kv_cache_config,
         )
-        if self.dp_size > 1:
-            num_tokens_across_dp = torch.full(
-                (self.dp_size,),
-                batch_size,
-                dtype=torch.int32,
-                device="cpu",
-            )
-        else:
-            num_tokens_across_dp = None
+        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, batch_size)
 
         # Warm up.
         with set_forward_context(
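
Note: a minimal, self-contained sanity check (assuming only torch is installed) that the one-line replacement above behaves the same as the removed if/else for both the single-rank and multi-rank cases; the helper body is copied from the dp_utils.py hunk below.

import torch


def make_num_tokens_across_dp(dp_size: int, num_tokens: int) -> torch.Tensor | None:
    # Body copied from the new helper added in vllm/v1/worker/gpu/dp_utils.py.
    if dp_size == 1:
        return None
    return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")


def old_inline_branch(dp_size: int, batch_size: int) -> torch.Tensor | None:
    # The branch removed from capture_graph() in the hunk above.
    if dp_size > 1:
        return torch.full((dp_size,), batch_size, dtype=torch.int32, device="cpu")
    return None


for dp_size, batch_size in [(1, 8), (2, 8), (4, 256)]:
    old = old_inline_branch(dp_size, batch_size)
    new = make_num_tokens_across_dp(dp_size, batch_size)
    assert (old is None and new is None) or torch.equal(old, new)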

vllm/v1/worker/gpu/dp_utils.py

Lines changed: 9 additions & 0 deletions
@@ -20,3 +20,12 @@ def get_batch_metadata_across_dp(
     tensor[1][dp_rank] = cudagraph_size
     dist.all_reduce(tensor, group=group)
     return tensor[0], tensor[1]
+
+
+def make_num_tokens_across_dp(
+    dp_size: int,
+    num_tokens: int,
+) -> torch.Tensor | None:
+    if dp_size == 1:
+        return None
+    return torch.full((dp_size,), num_tokens, dtype=torch.int32, device="cpu")
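
Note: a minimal usage sketch of the new helper (assuming vLLM at this commit is importable; otherwise the nine-line body above can be pasted in directly). It returns None when there is a single data-parallel rank, and otherwise a CPU int32 tensor with one entry per rank, all set to the local token count.

import torch

from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp

# Single data-parallel rank: no cross-rank tensor is needed.
assert make_num_tokens_across_dp(dp_size=1, num_tokens=512) is None

# Four ranks, each running 512 tokens: one entry per rank, kept on the CPU.
t = make_num_tokens_across_dp(dp_size=4, num_tokens=512)
assert t.shape == (4,) and t.dtype == torch.int32 and t.device.type == "cpu"
print(t)  # tensor([512, 512, 512, 512], dtype=torch.int32)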

vllm/v1/worker/gpu/model_runner.py

Lines changed: 7 additions & 9 deletions
@@ -35,7 +35,10 @@
 )
 from vllm.v1.worker.gpu.block_table import BlockTables
 from vllm.v1.worker.gpu.cudagraph_utils import CudaGraphManager
-from vllm.v1.worker.gpu.dp_utils import get_batch_metadata_across_dp
+from vllm.v1.worker.gpu.dp_utils import (
+    get_batch_metadata_across_dp,
+    make_num_tokens_across_dp,
+)
 from vllm.v1.worker.gpu.input_batch import (
     InputBatch,
     InputBuffers,
@@ -255,12 +258,7 @@ def _dummy_run(
         if not skip_attn:
             self.prepare_dummy_attn_metadata(input_batch)
 
-        if self.dp_size == 1:
-            num_tokens_across_dp: torch.Tensor | None = None
-        else:
-            num_tokens_across_dp = torch.full(
-                (self.dp_size,), num_tokens, dtype=torch.int32, device="cpu"
-            )
+        num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens)
         num_sampled_tokens = np.ones(input_batch.num_reqs, dtype=np.int32)
         with (
             self.maybe_dummy_run_with_lora(
@@ -816,7 +814,6 @@ def propose_draft(
             self.req_states.last_sampled_tokens,
             next_prefill_tokens,
         )
-        self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
         return draft_tokens
 
     def get_cudagraph_and_dp_padding(
@@ -1006,14 +1003,15 @@ def sample_tokens(
             input_batch, sampler_output.sampled_token_ids, num_sampled, num_rejected
         )
         if self.do_spec_decode:
-            _ = self.propose_draft(
+            draft_tokens = self.propose_draft(
                 input_batch,
                 sampling_metadata,
                 hidden_states,
                 None,  # aux_hidden_states
                 num_sampled,
                 num_rejected,
             )
+            self.req_states.draft_tokens[input_batch.idx_mapping] = draft_tokens
 
         if self.use_async_scheduling:
             return async_output
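
Note: a rough sketch of the reshuffled draft-token flow (the buffer shape, dtype, and sizes here are illustrative assumptions, not taken from the runner): propose_draft() now only returns the drafts, and sample_tokens() is the single place that scatters them back into per-request state via the batch's index mapping.

import torch

# Hypothetical stand-ins for the state touched in the hunks above.
max_num_reqs, num_spec_tokens = 8, 2
draft_token_buffer = torch.zeros(max_num_reqs, num_spec_tokens, dtype=torch.int64)  # ~ req_states.draft_tokens
idx_mapping = torch.tensor([5, 2, 7])  # batch slot -> request slot, ~ input_batch.idx_mapping
draft_tokens = torch.tensor([[11, 12], [21, 22], [31, 32]])  # ~ return value of propose_draft()

# What sample_tokens() now does right after propose_draft() returns:
draft_token_buffer[idx_mapping] = draft_tokens
print(draft_token_buffer[5])  # tensor([11, 12])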
