
support transformers multi-modal grpo #131

Open
hjh0119 wants to merge 3 commits into modelscope:main from hjh0119:refactor-mm

Conversation


@hjh0119 hjh0119 commented Mar 29, 2026

Summary

  • Support transformers-standard multimodal message format (content: List[Dict]) in template pipeline
  • Fix _check_max_length to only truncate input_ids/labels, keeping multimodal tensors intact
  • Clean Arrow-serialized None keys from content blocks to prevent Jinja template misparsing
  • Add multi_modal_data / mm_processor_kwargs pass-through in vLLM sampling pipeline
  • Forward encode() kwargs in LazyDataset for proper add_generation_prompt support
  • Add CLEVRProcessor, MultiModalAccuracyReward, multimodal GRPO demo
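For reference, a minimal sketch of the two content formats mentioned above (field values are made-up examples, not data from this PR):

```python
# transformers-standard multimodal message: "content" is a list of typed
# blocks rather than a plain string. Paths and text here are hypothetical.
message = {
    "role": "user",
    "content": [
        {"type": "image", "image": "clevr/images/000001.png"},
        {"type": "text", "text": "How many red cubes are in the scene?"},
    ],
}

# Legacy format, still accepted by the template pipeline:
legacy_message = {
    "role": "user",
    "content": "How many red cubes are in the scene?",
}
```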

Changed files

  • template/base.py — _build_mm_messages supports both List[Dict] and legacy str content;
    _apply_chat_template strips null keys; _check_max_length only truncates sequence fields
  • sampler/vllm_sampler/vllm_engine.py — sample() accepts a multi_modal_data dict directly
  • sampler/vllm_sampler/vllm_sampler.py — extract multi_modal_data from message content blocks
    for vLLM; add mm_processor_kwargs forwarding
  • dataset/lazy_dataset.py — forward encode() kwargs to batch_encode
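The null-key cleanup described above can be pictured with a small sketch (`strip_null_keys` is an illustrative name, not the PR's actual helper in template/base.py):

```python
def strip_null_keys(blocks):
    """Drop keys whose value is None from each content block.

    Arrow serialization pads every block with the union of all keys seen
    in the column, filling the missing ones with None; a Jinja chat
    template that branches on a block's keys can then misfire. Removing
    None-valued keys restores the original per-block shape.
    """
    return [{k: v for k, v in block.items() if v is not None} for block in blocks]
```

For example, a text block that Arrow padded with an `image: None` key comes back as a plain `{"type": "text", "text": ...}` block.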


@gemini-code-assist bot left a comment


Code Review

This pull request introduces multimodal GRPO training capabilities, featuring a new training script for the Qwen3.5 VL model on the CLEVR dataset. Core framework enhancements include a new MultiModalAccuracyReward, a CLEVRProcessor, and a refactored vLLM sampler that better supports multimodal data and standard message formats. Review feedback identifies a logic error in the training loop where metrics are reset prematurely, as well as opportunities to optimize reward function instantiation and fix a potential typo in the reward computation.

Comment on lines +85 to +93

```python
def compute_rewards(
    trajectories: List[Dict[str, Any]],
) -> Tuple[List[float], List[float], List[float]]:
    accuracy_reward_fn = MultiModalAccuracyReward()
    format_reward_fn = FormatReward()
    accuracy_rewards = accuracy_reward_fn(trajectories)
    format_rewards = format_reward_fn(trajectories, trajectories)
    total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
    return total_rewards, format_rewards, accuracy_rewards
```

Severity: high

There are a couple of improvements for compute_rewards:

  1. Inefficient Instantiation: MultiModalAccuracyReward and FormatReward are instantiated on every call, which is inefficient. Consider creating them once in main() and passing them as arguments.
  2. Potential Bug: format_reward_fn is called with two identical arguments (trajectories, trajectories). This is likely a typo and should probably be a single argument.

Here's a suggested refactoring that addresses both points. You would also need to instantiate the reward functions in main() and update the call site for compute_rewards.

Suggested change

```python
# before
def compute_rewards(
    trajectories: List[Dict[str, Any]],
) -> Tuple[List[float], List[float], List[float]]:
    accuracy_reward_fn = MultiModalAccuracyReward()
    format_reward_fn = FormatReward()
    accuracy_rewards = accuracy_reward_fn(trajectories)
    format_rewards = format_reward_fn(trajectories, trajectories)
    total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
    return total_rewards, format_rewards, accuracy_rewards
```

```python
# after
def compute_rewards(
    trajectories: List[Dict[str, Any]],
    accuracy_reward_fn: MultiModalAccuracyReward,
    format_reward_fn: FormatReward,
) -> Tuple[List[float], List[float], List[float]]:
    accuracy_rewards = accuracy_reward_fn(trajectories)
    format_rewards = format_reward_fn(trajectories)
    total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
    return total_rewards, format_rewards, accuracy_rewards
```

```python
model.save(f'mm-grpo-clevr-checkpoint-{optim_step}')
log_dict = metrics.calculate()
log_dict.update(model.calculate_metric(is_training=True))
metrics.reset()
```

Severity: high

The metrics.reset() call is inside the mini-batch loop, while metrics.accumulate() is called outside for the entire global batch. This leads to incorrect metric logging for all but the first mini-batch in a global step. The metrics.reset() at the start of the outer loop (line 174) is sufficient and this one should be removed.
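A toy sketch of the intended structure (the `Metrics` class and loop shape here are illustrative stand-ins, not the repo's actual code): reset once per global step, accumulate every mini-batch, calculate only after the full global batch.

```python
class Metrics:
    """Minimal accumulator, illustrative only."""
    def __init__(self):
        self.values = []
    def reset(self):
        self.values = []
    def accumulate(self, v):
        self.values.append(v)
    def calculate(self):
        return sum(self.values) / len(self.values)

metrics = Metrics()
global_batch = [[1.0, 2.0], [3.0, 4.0]]   # two mini-batches in one global step

metrics.reset()                           # once, at the start of the global step
for mini_batch in global_batch:
    for loss in mini_batch:
        metrics.accumulate(loss)
    # no metrics.reset() here: resetting inside this loop would discard
    # everything accumulated by earlier mini-batches
log_dict = {"loss": metrics.calculate()}  # averages the whole global batch: 2.5
```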

hjh0119 added 2 commits March 30, 2026 16:09
Resolved conflicts in template/base.py:
- _check_max_length: adopt upstream's cleaner _truncate_feature approach
- _process_mm_messages: merge our List[Dict] content support into
  upstream's refactored method structure

Made-with: Cursor