57 changes: 50 additions & 7 deletions modelopt/torch/utils/plugins/megatron_generate.py
@@ -208,13 +208,48 @@ def _forward_step_func(data, model):
# NOTE: we don't support traditional positional embedding. Only RoPE or YaRN are supported.
position_ids = None

output_tensor = model(
data["tokens"],
position_ids,
attention_mask,
inference_context=inference_context,
runtime_gather_output=True,
# Check if this is a VLM model (has vision inputs)
has_vision_inputs = (
("pixel_values" in data and data["pixel_values"] is not None)
or ("image_grid_thw" in data and data["image_grid_thw"] is not None)
or ("image_sizes" in data and data["image_sizes"] is not None)
Comment on lines +211 to +215
Contributor

Suggested change
# Check if this is a VLM model (has vision inputs)
has_vision_inputs = (
("pixel_values" in data and data["pixel_values"] is not None)
or ("image_grid_thw" in data and data["image_grid_thw"] is not None)
or ("image_sizes" in data and data["image_sizes"] is not None)
# Check if this is a VLM model (has vision inputs)
_has_pixel_values = data.get("pixel_values") is not None
_has_image_grid_thw = data.get("image_grid_thw") is not None
_has_image_sizes = data.get("image_sizes") is not None
has_vision_inputs = _has_pixel_values or _has_image_grid_thw or _has_image_sizes

)

if has_vision_inputs:
# For VLM models:
# - position_ids: [batch, seq_len] (required for RoPE with multi-modal positions)
# - attention_mask: [batch, seq_len] (simple 1D boolean mask, not 4D causal)
vlm_position_ids = (
torch.arange(seq_len, dtype=torch.long, device=device)
.unsqueeze(0)
.expand(batch_size, -1)
)
vlm_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)

forward_args = {
"input_ids": data["tokens"],
"position_ids": vlm_position_ids,
"attention_mask": vlm_attention_mask,
"runtime_gather_output": True,
}
# Add vision inputs
if "pixel_values" in data and data["pixel_values"] is not None:
forward_args["pixel_values"] = data["pixel_values"]
if "image_grid_thw" in data and data["image_grid_thw"] is not None:
forward_args["image_grid_thw"] = data["image_grid_thw"]
if "image_sizes" in data and data["image_sizes"] is not None:
Comment on lines +236 to +240
Contributor
@AAnoosheh Jan 13, 2026

Suggested change
if "pixel_values" in data and data["pixel_values"] is not None:
forward_args["pixel_values"] = data["pixel_values"]
if "image_grid_thw" in data and data["image_grid_thw"] is not None:
forward_args["image_grid_thw"] = data["image_grid_thw"]
if "image_sizes" in data and data["image_sizes"] is not None:
if _has_pixel_values:
forward_args["pixel_values"] = data["pixel_values"]
if _has_image_grid_thw:
forward_args["image_grid_thw"] = data["image_grid_thw"]
if _has_image_sizes:

forward_args["image_sizes"] = data["image_sizes"]

output_tensor = model(**forward_args)
Comment on lines +218 to +243
Contributor

⚠️ Potential issue | 🟠 Major

VLM path does not support KV-cache decoding.

The VLM branch omits inference_context, so KV-cache is silently disabled for vision-language models even when enable_kv_cache=True. Additionally, vlm_position_ids always starts from 0, which would be incorrect during decode steps if KV-cache were used.

Consider either:

  1. Passing inference_context and computing the correct position offset during decode, or
  2. Explicitly disabling KV-cache when vision inputs are detected (similar to lines 172-174 for sequence parallelism).
🤖 Prompt for AI Agents
In @modelopt/torch/utils/plugins/megatron_generate.py around lines 218-243,
the VLM branch currently omits passing inference_context and always builds
vlm_position_ids starting at 0, which disables/breaks KV-cache during decoding;
fix by: when has_vision_inputs is true and an inference_context is provided (and
enable_kv_cache is true), include inference_context in forward_args
(forward_args["inference_context"] = inference_context) and compute
vlm_position_ids by adding the decode offset from the context (e.g., base =
getattr(inference_context, "position_offset", getattr(inference_context,
"curr_seq_len", 0)); vlm_position_ids = torch.arange(base, base + seq_len,
dtype=torch.long, device=device).unsqueeze(0).expand(batch_size, -1));
alternatively, if you prefer to disallow KV-cache for VLMs, explicitly set
forward_args["inference_context"] = None (or skip passing it) and ensure
enable_kv_cache is treated as disabled when has_vision_inputs is true.
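A minimal sketch of the first option, following the hints in the prompt above; position_offset and curr_seq_len are assumed attribute names and would need to be checked against the Megatron inference context actually in use:

# Sketch only (assumed attribute names): keep the KV-cache usable in the VLM
# branch by passing the inference context through and offsetting position ids
# by the number of tokens already cached.
if has_vision_inputs and inference_context is not None:
    base = getattr(
        inference_context,
        "position_offset",
        getattr(inference_context, "curr_seq_len", 0),  # hypothetical fallbacks
    )
    vlm_position_ids = (
        torch.arange(base, base + seq_len, dtype=torch.long, device=device)
        .unsqueeze(0)
        .expand(batch_size, -1)
    )
    forward_args["position_ids"] = vlm_position_ids
    forward_args["inference_context"] = inference_context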

else:
# For text-only LLM models
output_tensor = model(
data["tokens"],
position_ids,
attention_mask,
inference_context=inference_context,
runtime_gather_output=True,
)
return output_tensor, _dummy_loss_func

disable_tqdm = disable_tqdm or torch.distributed.get_rank() > 0
@@ -248,9 +283,17 @@ def _forward_step_func(data, model):
else:
tokens = input_ids

data_dict = {"tokens": tokens}
if pixel_values is not None:
data_dict["pixel_values"] = pixel_values
if image_grid_thw is not None:
data_dict["image_grid_thw"] = image_grid_thw
if image_sizes is not None:
data_dict["image_sizes"] = image_sizes
Comment on lines +286 to +292
Contributor

⚠️ Potential issue | 🟠 Major

Vision inputs are passed on every generation step instead of just prefill.

Vision inputs should only be processed during the prefill phase (step 0). Passing them on every decode step is wasteful and may cause unexpected behavior in some VLM architectures.

Proposed fix: Only include vision inputs on the first step
         data_dict = {"tokens": tokens}
-        if pixel_values is not None:
-            data_dict["pixel_values"] = pixel_values
-        if image_grid_thw is not None:
-            data_dict["image_grid_thw"] = image_grid_thw
-        if image_sizes is not None:
-            data_dict["image_sizes"] = image_sizes
+        # Vision inputs should only be processed during prefill (step 0)
+        if step == 0:
+            if pixel_values is not None:
+                data_dict["pixel_values"] = pixel_values
+            if image_grid_thw is not None:
+                data_dict["image_grid_thw"] = image_grid_thw
+            if image_sizes is not None:
+                data_dict["image_sizes"] = image_sizes
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
data_dict = {"tokens": tokens}
if pixel_values is not None:
data_dict["pixel_values"] = pixel_values
if image_grid_thw is not None:
data_dict["image_grid_thw"] = image_grid_thw
if image_sizes is not None:
data_dict["image_sizes"] = image_sizes
data_dict = {"tokens": tokens}
# Vision inputs should only be processed during prefill (step 0)
if step == 0:
if pixel_values is not None:
data_dict["pixel_values"] = pixel_values
if image_grid_thw is not None:
data_dict["image_grid_thw"] = image_grid_thw
if image_sizes is not None:
data_dict["image_sizes"] = image_sizes
🤖 Prompt for AI Agents
In @modelopt/torch/utils/plugins/megatron_generate.py around lines 286-292,
the vision inputs (pixel_values, image_grid_thw, image_sizes) are being added to
data_dict on every decode step; change the logic so these keys are only added
during the prefill/first generation step (e.g., when step == 0 or when an
is_prefill flag is true). Locate the block building data_dict (symbols:
data_dict, tokens, pixel_values, image_grid_thw, image_sizes) inside the
generation loop/function and wrap the conditional additions of pixel_values,
image_grid_thw, and image_sizes so they execute only for the initial prefill
step.
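As an aside, one way to package that gating is a small helper, sketched here under the assumption that the caller can tell prefill from decode (e.g., step == 0); build_data_dict is a hypothetical name, not part of the file:

# Sketch only: build the per-step data dict, attaching vision tensors on the
# prefill pass alone so later decode steps send text tokens only.
def build_data_dict(tokens, is_prefill, pixel_values=None,
                    image_grid_thw=None, image_sizes=None):
    data_dict = {"tokens": tokens}
    if is_prefill:  # typically step == 0 in the generation loop
        for key, value in (
            ("pixel_values", pixel_values),
            ("image_grid_thw", image_grid_thw),
            ("image_sizes", image_sizes),
        ):
            if value is not None:
                data_dict[key] = value
    return data_dict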


list_of_logits = get_forward_backward_func()(
forward_step_func=_forward_step_func,
data_iterator=[{"tokens": tokens}],
data_iterator=[data_dict],
model=model,
num_microbatches=1,
seq_length=tokens.shape[-1],