Skip to content

Commit ee80aee

Browse files
authored
[Model Runner V2] Minor cleanup for build_attn_metadata (vllm-project#29576)
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
1 parent 0aeb698 commit ee80aee

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

vllm/v1/worker/gpu/attn_utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
KVCacheConfig,
1919
KVCacheSpec,
2020
)
21-
from vllm.v1.utils import CpuGpuBuffer
2221
from vllm.v1.worker.utils import bind_kv_cache
2322

2423

@@ -145,17 +144,16 @@ def build_attn_metadata(
145144
attn_metadata_builders: list[AttentionMetadataBuilder],
146145
num_reqs: int,
147146
num_tokens: int,
148-
query_start_loc: CpuGpuBuffer,
147+
query_start_loc_gpu: torch.Tensor,
148+
query_start_loc_cpu: torch.Tensor,
149149
seq_lens: torch.Tensor,
150150
seq_lens_np: np.ndarray,
151151
num_computed_tokens_cpu: torch.Tensor | None,
152152
block_tables: Sequence[torch.Tensor],
153153
slot_mappings: torch.Tensor,
154154
kv_cache_config: KVCacheConfig,
155155
) -> dict[str, Any]:
156-
query_start_loc_gpu = query_start_loc.gpu[: num_reqs + 1]
157-
query_start_loc_cpu = query_start_loc.cpu[: num_reqs + 1]
158-
max_query_len = int(query_start_loc.np[: num_reqs + 1].max())
156+
max_query_len = int(query_start_loc_cpu.max())
159157
seq_lens = seq_lens[:num_reqs]
160158
seq_lens_cpu = torch.from_numpy(seq_lens_np)
161159
max_seq_len = int(seq_lens_np.max())

vllm/v1/worker/gpu/cudagraph_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ def capture_graph(
120120
attn_metadata_builders=attn_metadata_builders,
121121
num_reqs=batch_size,
122122
num_tokens=batch_size,
123-
query_start_loc=input_buffers.query_start_loc,
123+
query_start_loc_gpu=input_buffers.query_start_loc.gpu[: batch_size + 1],
124+
query_start_loc_cpu=input_buffers.query_start_loc.cpu[: batch_size + 1],
124125
seq_lens=input_buffers.seq_lens,
125126
seq_lens_np=np.full(batch_size, self.max_model_len, dtype=np.int32),
126127
num_computed_tokens_cpu=None, # FIXME

vllm/v1/worker/gpu/model_runner.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -226,11 +226,15 @@ def prepare_dummy_attn_metadata(self, input_batch: InputBatch) -> None:
226226
num_computed_tokens = torch.zeros(
227227
input_batch.num_reqs, dtype=torch.int32, device=self.device
228228
)
229+
query_start_loc = self.input_buffers.query_start_loc
230+
query_start_loc_gpu = query_start_loc.gpu[: input_batch.num_reqs + 1]
231+
query_start_loc_cpu = query_start_loc.cpu[: input_batch.num_reqs + 1]
229232
attn_metadata = build_attn_metadata(
230233
attn_metadata_builders=self.attn_metadata_builders,
231234
num_reqs=input_batch.num_reqs,
232235
num_tokens=input_batch.num_tokens,
233-
query_start_loc=self.input_buffers.query_start_loc,
236+
query_start_loc_gpu=query_start_loc_gpu,
237+
query_start_loc_cpu=query_start_loc_cpu,
234238
seq_lens=self.input_buffers.seq_lens,
235239
seq_lens_np=input_batch.seq_lens_np,
236240
num_computed_tokens_cpu=num_computed_tokens,
@@ -515,6 +519,7 @@ def prepare_inputs(
515519
self.input_buffers.query_start_loc.np[num_reqs + 1 :] = num_tokens
516520
self.input_buffers.query_start_loc.copy_to_gpu()
517521
query_start_loc_gpu = self.input_buffers.query_start_loc.gpu[: num_reqs + 1]
522+
query_start_loc_cpu = self.input_buffers.query_start_loc.cpu[: num_reqs + 1]
518523
query_start_loc_np = self.input_buffers.query_start_loc.np[: num_reqs + 1]
519524

520525
# Copy prefill tokens from CPU to GPU.
@@ -572,7 +577,8 @@ def prepare_inputs(
572577
attn_metadata_builders=self.attn_metadata_builders,
573578
num_reqs=num_reqs,
574579
num_tokens=num_tokens,
575-
query_start_loc=self.input_buffers.query_start_loc,
580+
query_start_loc_gpu=query_start_loc_gpu,
581+
query_start_loc_cpu=query_start_loc_cpu,
576582
seq_lens=self.input_buffers.seq_lens,
577583
seq_lens_np=seq_lens_np,
578584
num_computed_tokens_cpu=num_computed_tokens,

0 commit comments

Comments
 (0)