Commit cc7ec93

fix: snapshot past sequence length to prevent CUDA index out of bounds error
- Added a snapshot of the past sequence length before the layer loop to ensure consistent indexing during decoder layer calls.
- This change addresses a potential "device-side assert: index out of bounds" error by preventing incorrect indexing of key_layer with attention_mask indices.
1 parent 137911f commit cc7ec93

1 file changed: src/rabbitllm/engine/base.py

Lines changed: 13 additions & 1 deletion
@@ -990,6 +990,18 @@ def _load_cpu_or_cache(name: str):
         # back those dead pool blocks so every subsequent layer load has room.
         clean_memory()
 
+        # Snapshot the past sequence length ONCE before the layer loop.
+        # DiskKVCache._seq_len is updated inside update() on every decoder layer call,
+        # so querying it inside the loop would give len_p+1, len_p+2, ... for successive
+        # layers — causing Flash Attention's _upad_input to try indexing key_layer
+        # (size past+1) with attention_mask indices up to past+N, triggering the
+        # "device-side assert: index out of bounds" CUDA error on decode step 2+.
+        _past_seq_len_snapshot = (
+            self.get_past_key_values_cache_seq_len(past_key_values)
+            if past_key_values is not None
+            else 0
+        )
+
         layer_iter = enumerate(zip(self.layer_names, self.layers))
         if getattr(self, "show_layer_progress", True):
             layer_iter = tqdm(
@@ -1169,7 +1181,7 @@ def _load_cpu_or_cache(name: str):
             self._fix_layer_attention_head_dim(layer)
             if past_key_values is not None:
                 k_cache, v_cache = self._get_layer_past_kv(past_key_values, i - 1)
-                len_p = self.get_past_key_values_cache_seq_len(past_key_values)
+                len_p = _past_seq_len_snapshot
                 len_s = self.get_sequence_len(seq)
                 position_ids_args = self.get_position_ids_args(
                     position_ids, len_p, len_s
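The failure mode the commit fixes can be shown in isolation: a cache counter that grows on every layer's update() must be read once before the loop, otherwise each successive layer observes a different "past" length. The sketch below is a minimal illustration with hypothetical stand-in names (`ToyKVCache`, `past_lengths`), not the actual DiskKVCache implementation.

```python
# Minimal sketch of the snapshot-before-loop fix, with hypothetical names.

class ToyKVCache:
    """Stand-in for DiskKVCache: _seq_len grows on every update() call."""

    def __init__(self, seq_len: int):
        self._seq_len = seq_len

    def update(self):
        # In the real cache this appends new key/value states; here we only
        # model the side effect that matters: the length counter grows.
        self._seq_len += 1

    @property
    def seq_len(self) -> int:
        return self._seq_len


def past_lengths(cache: ToyKVCache, num_layers: int, snapshot: bool) -> list[int]:
    """Return the past length each decoder layer would observe."""
    past = cache.seq_len if snapshot else None  # snapshot ONCE before the loop
    lengths = []
    for _ in range(num_layers):
        lengths.append(past if snapshot else cache.seq_len)
        cache.update()  # mutates _seq_len, as each layer's cache update does
    return lengths


# Querying inside the loop drifts: each layer sees a different past length,
# which is what made the attention-mask indices exceed key_layer's size.
print(past_lengths(ToyKVCache(5), 3, snapshot=False))  # [5, 6, 7]
# Snapshotting keeps every layer consistent with the tensors actually built.
print(past_lengths(ToyKVCache(5), 3, snapshot=True))   # [5, 5, 5]
```

The design point is the same as in the diff: the length used for position IDs and mask construction must reflect the state at the start of the decode step, not the mid-loop state of a mutating cache.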
