Fix mRoPE position ID crash on Qwen2-VL prompt truncation #482
Mr-Neutr0n wants to merge 1 commit into microsoft:main
Conversation
When training Qwen2.5-VL with agent-lightning + verl, prompt truncation changes the token count, but `image_grid_thw` is computed from the original (untruncated) `image_urls`. This causes `get_rope_index` to fail with a shape mismatch, because it finds fewer image tokens in the truncated `input_ids` than entries in `image_grid_thw`. The fix: after prompt truncation, count the remaining image regions in the truncated token sequence and slice `image_urls` to match before computing `image_grid_thw`, ensuring consistency between the token content and the mRoPE spatial metadata. Fixes microsoft#441
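A minimal sketch of the counting step described above. The token ID constants are illustrative placeholders; in the real code they come from the Qwen2-VL processor/tokenizer config (`vision_start_token_id`, `image_token_id`), the same IDs `get_rope_index` scans for:

```python
from typing import List

# Illustrative placeholder IDs; the real values come from the Qwen2-VL
# tokenizer config, as used by get_rope_index.
VISION_START_TOKEN_ID = 151652
IMAGE_TOKEN_ID = 151655

def count_images_in_tokens(token_ids: List[int]) -> int:
    """Count image regions that survive truncation: a region is counted only
    when a vision-start token is immediately followed by an image-pad token."""
    count = 0
    for i, tok in enumerate(token_ids):
        if (tok == VISION_START_TOKEN_ID
                and i + 1 < len(token_ids)
                and token_ids[i + 1] == IMAGE_TOKEN_ID):
            count += 1
    return count
```

`image_urls` can then be sliced to `image_urls[:count]` so that `image_grid_thw` is built from exactly the images still referenced by tokens.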
force-pushed from bdd1c8d to ca0be5a
Friendly bump! Let me know if there's anything I should update or improve to help move this forward.
```python
    raise ValueError(f"Relative path '{path}' requires 'image_base_dir' to be set.")
return os.path.join(self.image_base_dir, path)
```

```python
def _count_images_in_tokens(self, token_ids: List[int]) -> int:
```
I see the PR tried to match how images are counted for image_grid_thw w.r.t. get_rope_index. However, the current mechanism will still fail in a corner case where an image is truncated in the middle: the count will still increment by 1, and get_rope_index will fail at the same place.
IMO, we should simply leverage the existing is_dropped_list, put dummy pos_ids, and skip _compute_mrope_position_ids for those samples. They should be treated the same way as text prompts that exceed the length limit. So what we should do is:
```python
if self._use_mrope and is_dropped_list[i]:
    # Don't call get_rope_index — it would crash on truncated images.
    # is_drop_mask will remove this sample in the trainer.
    position_ids_list.append(torch.zeros(4, seq_len, dtype=torch.long, device=device))
else:
    pos_ids = self._compute_mrope_position_ids(...)
    position_ids_list.append(pos_ids)
There is no harm in putting the current code in place, but it's not a fix for all cases. Thoughts @Mr-Neutr0n?
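A torch-free sketch of that suggestion, assuming a per-sample dropped flag as in the comment above (`build_position_ids` and `compute_mrope_position_ids` are hypothetical stand-ins; the real code builds torch tensors and calls `get_rope_index`):

```python
def compute_mrope_position_ids(sample, seq_len):
    # Stand-in for the real get_rope_index-based mRoPE computation.
    return [list(range(seq_len)) for _ in range(4)]

def build_position_ids(samples, seq_len):
    """Drop-mask strategy: truncated multimodal samples get all-zero dummy
    position ids and skip the mRoPE computation entirely. The trainer later
    removes them via is_drop_mask, so the dummy values never affect training."""
    position_ids_list = []
    for sample in samples:
        if sample["use_mrope"] and sample["is_dropped"]:
            # Dummy (4, seq_len) position ids, standing in for torch.zeros(...).
            position_ids_list.append([[0] * seq_len for _ in range(4)])
        else:
            position_ids_list.append(compute_mrope_position_ids(sample, seq_len))
    return position_ids_list
```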
Summary
Fixes #441
When training Qwen2.5-VL with agent-lightning + verl, the model crashes in `get_rope_index` with a shape mismatch: the call fails because the `llm_positions` length differs from the attention-mask true count.

Root cause: In `get_train_data_batch`, prompt truncation (`prompt_ids[:max_prompt_length]`) changes the token count, potentially removing image placeholder tokens. However, `image_grid_thw` is computed from the original (untruncated) `image_urls` list. When `get_rope_index` processes the truncated sequence, it finds fewer `<|vision_start|><|image_pad|>` regions than `image_grid_thw` entries, causing the position ID length to diverge from the attention mask count.

Fix: After prompt truncation, count the remaining image regions in the truncated token sequence using the same `vision_start_token_id` + `image_token_id` pattern that `get_rope_index` uses, and slice `image_urls` to match before computing `image_grid_thw`.

- Added a `_count_images_in_tokens()` helper method to detect image regions in token sequences
- Sliced `image_urls` to stay consistent with truncated prompts

Test plan

- A prompt that exceeds `max_prompt_length` and contains images no longer crashes in `get_rope_index`
- A prompt within `max_prompt_length` is unaffected (no truncation, all images retained)
- Text-only training is unaffected (`_use_mrope` is `False`)
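The fix can be sketched end to end as follows. Argument names and token IDs here are placeholders rather than the repo's exact identifiers, and (as the review comment notes) this pattern count can still mismatch when an image region is cut in the middle of its pad tokens:

```python
def truncate_with_consistent_images(prompt_ids, image_urls, max_prompt_length,
                                    vision_start_id, image_pad_id):
    """Truncate the prompt, then slice image_urls so image_grid_thw is later
    computed from exactly the images still referenced by surviving tokens."""
    truncated = prompt_ids[:max_prompt_length]
    # Same vision_start + image_pad pattern that get_rope_index scans for.
    remaining = sum(
        1
        for i, tok in enumerate(truncated)
        if tok == vision_start_id
        and i + 1 < len(truncated)
        and truncated[i + 1] == image_pad_id
    )
    return truncated, image_urls[:remaining]
```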