diff --git a/modelopt/torch/utils/plugins/megatron_generate.py b/modelopt/torch/utils/plugins/megatron_generate.py
index d542d935a..ba21a0dea 100644
--- a/modelopt/torch/utils/plugins/megatron_generate.py
+++ b/modelopt/torch/utils/plugins/megatron_generate.py
@@ -208,13 +208,48 @@ def _forward_step_func(data, model):
         # NOTE: we don't support traditional positional embedding. Only RoPE or YaRN are supported.
         position_ids = None

-        output_tensor = model(
-            data["tokens"],
-            position_ids,
-            attention_mask,
-            inference_context=inference_context,
-            runtime_gather_output=True,
+        # Check whether this is a VLM batch (any vision inputs present).
+        has_vision_inputs = (
+            ("pixel_values" in data and data["pixel_values"] is not None)
+            or ("image_grid_thw" in data and data["image_grid_thw"] is not None)
+            or ("image_sizes" in data and data["image_sizes"] is not None)
         )
+
+        if has_vision_inputs:
+            # For VLM models:
+            # - position_ids: [batch, seq_len] (required for RoPE with multi-modal positions)
+            # - attention_mask: [batch, seq_len] (boolean padding-style mask, not a 4D causal mask)
+            vlm_position_ids = (
+                torch.arange(seq_len, dtype=torch.long, device=device)
+                .unsqueeze(0)
+                .expand(batch_size, -1)
+            )
+            vlm_attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)
+
+            forward_args = {
+                "input_ids": data["tokens"],
+                "position_ids": vlm_position_ids,
+                "attention_mask": vlm_attention_mask,
+                "runtime_gather_output": True,
+            }
+            # Attach whichever vision inputs are present.
+            if "pixel_values" in data and data["pixel_values"] is not None:
+                forward_args["pixel_values"] = data["pixel_values"]
+            if "image_grid_thw" in data and data["image_grid_thw"] is not None:
+                forward_args["image_grid_thw"] = data["image_grid_thw"]
+            if "image_sizes" in data and data["image_sizes"] is not None:
+                forward_args["image_sizes"] = data["image_sizes"]
+
+            output_tensor = model(**forward_args)
+        else:
+            # For text-only LLM models
+            output_tensor = model(
+                data["tokens"],
+                position_ids,
+                attention_mask,
+                inference_context=inference_context,
+                runtime_gather_output=True,
+            )

         return output_tensor, _dummy_loss_func

     disable_tqdm = disable_tqdm or torch.distributed.get_rank() > 0
@@ -248,9 +283,17 @@ def _forward_step_func(data, model):
     else:
         tokens = input_ids

+    data_dict = {"tokens": tokens}
+    if pixel_values is not None:
+        data_dict["pixel_values"] = pixel_values
+    if image_grid_thw is not None:
+        data_dict["image_grid_thw"] = image_grid_thw
+    if image_sizes is not None:
+        data_dict["image_sizes"] = image_sizes
+
     list_of_logits = get_forward_backward_func()(
         forward_step_func=_forward_step_func,
-        data_iterator=[{"tokens": tokens}],
+        data_iterator=[data_dict],
         model=model,
         num_microbatches=1,
         seq_length=tokens.shape[-1],
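
Reviewer note: the dispatch added in the first hunk can be read as the standalone helper below. This is an illustrative sketch only, not part of the patch: build_vlm_forward_args is a hypothetical name, and batch_size, seq_len, and device are assumed to be derived from data["tokens"] earlier in _forward_step_func, as in the surrounding code.

    import torch

    def build_vlm_forward_args(data, batch_size, seq_len, device):
        """Return kwargs for the VLM forward call, or None for the text-only path."""
        vision_keys = ("pixel_values", "image_grid_thw", "image_sizes")
        if all(data.get(key) is None for key in vision_keys):
            return None  # caller falls back to the positional text-only call

        # [batch, seq_len] position ids, as required for RoPE with multi-modal positions.
        position_ids = (
            torch.arange(seq_len, dtype=torch.long, device=device)
            .unsqueeze(0)
            .expand(batch_size, -1)
        )
        # [batch, seq_len] boolean padding-style mask (all True: no padding during prefill).
        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.bool, device=device)

        forward_args = {
            "input_ids": data["tokens"],
            "position_ids": position_ids,
            "attention_mask": attention_mask,
            "runtime_gather_output": True,
        }
        # Attach only the vision tensors that are actually present in the batch.
        for key in vision_keys:
            if data.get(key) is not None:
                forward_args[key] = data[key]
        return forward_args

With such a helper, the branch in _forward_step_func reduces to calling build_vlm_forward_args(...) and, when the result is not None, invoking model(**forward_args); otherwise the original positional text-only call runs unchanged.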