-
Notifications
You must be signed in to change notification settings - Fork 239
Support megatron generate for vlm #773
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -208,13 +208,48 @@ def _forward_step_func(data, model): | |||||||||||||||||||||||||||||||||
| # NOTE: we don't support traditional positional embedding. Only RoPE or YaRN are supported. | ||||||||||||||||||||||||||||||||||
| position_ids = None | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| output_tensor = model( | ||||||||||||||||||||||||||||||||||
| data["tokens"], | ||||||||||||||||||||||||||||||||||
| position_ids, | ||||||||||||||||||||||||||||||||||
| attention_mask, | ||||||||||||||||||||||||||||||||||
| inference_context=inference_context, | ||||||||||||||||||||||||||||||||||
| runtime_gather_output=True, | ||||||||||||||||||||||||||||||||||
| # Check if this is a VLM model (has vision inputs) | ||||||||||||||||||||||||||||||||||
| has_vision_inputs = ( | ||||||||||||||||||||||||||||||||||
| ("pixel_values" in data and data["pixel_values"] is not None) | ||||||||||||||||||||||||||||||||||
| or ("image_grid_thw" in data and data["image_grid_thw"] is not None) | ||||||||||||||||||||||||||||||||||
| or ("image_sizes" in data and data["image_sizes"] is not None) | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| if has_vision_inputs: | ||||||||||||||||||||||||||||||||||
| # For VLM models: | ||||||||||||||||||||||||||||||||||
| # - position_ids: [batch, seq_len] (required for RoPE with multi-modal positions) | ||||||||||||||||||||||||||||||||||
| # - attention_mask: [batch, seq_len] (simple 1D boolean mask, not 4D causal) | ||||||||||||||||||||||||||||||||||
| vlm_position_ids = ( | ||||||||||||||||||||||||||||||||||
| torch.arange(seq_len, dtype=torch.long, device=device) | ||||||||||||||||||||||||||||||||||
| .unsqueeze(0) | ||||||||||||||||||||||||||||||||||
| .expand(batch_size, -1) | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| vlm_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device) | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| forward_args = { | ||||||||||||||||||||||||||||||||||
| "input_ids": data["tokens"], | ||||||||||||||||||||||||||||||||||
| "position_ids": vlm_position_ids, | ||||||||||||||||||||||||||||||||||
| "attention_mask": vlm_attention_mask, | ||||||||||||||||||||||||||||||||||
| "runtime_gather_output": True, | ||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||
| # Add vision inputs | ||||||||||||||||||||||||||||||||||
| if "pixel_values" in data and data["pixel_values"] is not None: | ||||||||||||||||||||||||||||||||||
| forward_args["pixel_values"] = data["pixel_values"] | ||||||||||||||||||||||||||||||||||
| if "image_grid_thw" in data and data["image_grid_thw"] is not None: | ||||||||||||||||||||||||||||||||||
| forward_args["image_grid_thw"] = data["image_grid_thw"] | ||||||||||||||||||||||||||||||||||
| if "image_sizes" in data and data["image_sizes"] is not None: | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+236
to
+240
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||||||||||||||||||||||||||||||
| forward_args["image_sizes"] = data["image_sizes"] | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| output_tensor = model(**forward_args) | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+218
to
+243
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. VLM path does not support KV-cache decoding. The VLM branch omits Consider either:
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| # For text-only LLM models | ||||||||||||||||||||||||||||||||||
| output_tensor = model( | ||||||||||||||||||||||||||||||||||
| data["tokens"], | ||||||||||||||||||||||||||||||||||
| position_ids, | ||||||||||||||||||||||||||||||||||
| attention_mask, | ||||||||||||||||||||||||||||||||||
| inference_context=inference_context, | ||||||||||||||||||||||||||||||||||
| runtime_gather_output=True, | ||||||||||||||||||||||||||||||||||
| ) | ||||||||||||||||||||||||||||||||||
| return output_tensor, _dummy_loss_func | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| disable_tqdm = disable_tqdm or torch.distributed.get_rank() > 0 | ||||||||||||||||||||||||||||||||||
|
|
@@ -248,9 +283,17 @@ def _forward_step_func(data, model): | |||||||||||||||||||||||||||||||||
| else: | ||||||||||||||||||||||||||||||||||
| tokens = input_ids | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| data_dict = {"tokens": tokens} | ||||||||||||||||||||||||||||||||||
| if pixel_values is not None: | ||||||||||||||||||||||||||||||||||
| data_dict["pixel_values"] = pixel_values | ||||||||||||||||||||||||||||||||||
| if image_grid_thw is not None: | ||||||||||||||||||||||||||||||||||
| data_dict["image_grid_thw"] = image_grid_thw | ||||||||||||||||||||||||||||||||||
| if image_sizes is not None: | ||||||||||||||||||||||||||||||||||
| data_dict["image_sizes"] = image_sizes | ||||||||||||||||||||||||||||||||||
|
Comment on lines
+286
to
+292
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Vision inputs are passed on every generation step instead of just prefill. Vision inputs should only be processed during the prefill phase (step 0). Passing them on every decode step is wasteful and may cause unexpected behavior in some VLM architectures. Proposed fix: Only include vision inputs on the first step data_dict = {"tokens": tokens}
- if pixel_values is not None:
- data_dict["pixel_values"] = pixel_values
- if image_grid_thw is not None:
- data_dict["image_grid_thw"] = image_grid_thw
- if image_sizes is not None:
- data_dict["image_sizes"] = image_sizes
+ # Vision inputs should only be processed during prefill (step 0)
+ if step == 0:
+ if pixel_values is not None:
+ data_dict["pixel_values"] = pixel_values
+ if image_grid_thw is not None:
+ data_dict["image_grid_thw"] = image_grid_thw
+ if image_sizes is not None:
+ data_dict["image_sizes"] = image_sizes📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
| list_of_logits = get_forward_backward_func()( | ||||||||||||||||||||||||||||||||||
| forward_step_func=_forward_step_func, | ||||||||||||||||||||||||||||||||||
| data_iterator=[{"tokens": tokens}], | ||||||||||||||||||||||||||||||||||
| data_iterator=[data_dict], | ||||||||||||||||||||||||||||||||||
| model=model, | ||||||||||||||||||||||||||||||||||
| num_microbatches=1, | ||||||||||||||||||||||||||||||||||
| seq_length=tokens.shape[-1], | ||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.