support transformers multi-modal grpo #131
Conversation
Code Review
This pull request introduces multimodal GRPO training capabilities, featuring a new training script for the Qwen3.5 VL model on the CLEVR dataset. Core framework enhancements include a new MultiModalAccuracyReward, a CLEVRProcessor, and a refactored vLLM sampler that better supports multimodal data and standard message formats. Review feedback identifies a logic error in the training loop where metrics are reset prematurely, as well as opportunities to optimize reward function instantiation and fix a potential typo in the reward computation.
```python
def compute_rewards(
    trajectories: List[Dict[str, Any]],
) -> Tuple[List[float], List[float], List[float]]:
    accuracy_reward_fn = MultiModalAccuracyReward()
    format_reward_fn = FormatReward()
    accuracy_rewards = accuracy_reward_fn(trajectories)
    format_rewards = format_reward_fn(trajectories, trajectories)
    total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
    return total_rewards, format_rewards, accuracy_rewards
```
There are a couple of improvements for `compute_rewards`:

- **Inefficient instantiation:** `MultiModalAccuracyReward` and `FormatReward` are instantiated on every call, which is inefficient. Consider creating them once in `main()` and passing them as arguments.
- **Potential bug:** `format_reward_fn` is called with two identical arguments (`trajectories, trajectories`). This is likely a typo and should probably be a single argument.
Here's a suggested refactoring that addresses both points. You would also need to instantiate the reward functions in `main()` and update the call site for `compute_rewards`.
```diff
 def compute_rewards(
     trajectories: List[Dict[str, Any]],
+    accuracy_reward_fn: MultiModalAccuracyReward,
+    format_reward_fn: FormatReward,
 ) -> Tuple[List[float], List[float], List[float]]:
-    accuracy_reward_fn = MultiModalAccuracyReward()
-    format_reward_fn = FormatReward()
     accuracy_rewards = accuracy_reward_fn(trajectories)
-    format_rewards = format_reward_fn(trajectories, trajectories)
+    format_rewards = format_reward_fn(trajectories)
     total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
     return total_rewards, format_rewards, accuracy_rewards
```
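To illustrate the suggested call-site change, here is a minimal, self-contained sketch. The two stub reward classes below are assumptions standing in for the project's real `MultiModalAccuracyReward` and `FormatReward`; only the `compute_rewards` signature matches the suggestion above.

```python
from typing import Any, Dict, List, Tuple

# Stub reward classes for illustration only -- the real project classes
# presumably score model trajectories against ground truth / format rules.
class MultiModalAccuracyReward:
    def __call__(self, trajectories: List[Dict[str, Any]]) -> List[float]:
        return [1.0 if t.get("correct") else 0.0 for t in trajectories]

class FormatReward:
    def __call__(self, trajectories: List[Dict[str, Any]]) -> List[float]:
        return [0.5 if t.get("well_formatted") else 0.0 for t in trajectories]

def compute_rewards(
    trajectories: List[Dict[str, Any]],
    accuracy_reward_fn: MultiModalAccuracyReward,
    format_reward_fn: FormatReward,
) -> Tuple[List[float], List[float], List[float]]:
    accuracy_rewards = accuracy_reward_fn(trajectories)
    format_rewards = format_reward_fn(trajectories)
    total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
    return total_rewards, format_rewards, accuracy_rewards

# In main(): instantiate the reward functions ONCE, then reuse them
# across all training steps instead of rebuilding them per call.
accuracy_fn = MultiModalAccuracyReward()
format_fn = FormatReward()
trajs = [
    {"correct": True, "well_formatted": True},
    {"correct": False, "well_formatted": True},
]
total, fmt, acc = compute_rewards(trajs, accuracy_fn, format_fn)
print(total)  # [1.5, 0.5]
```

This keeps reward construction out of the hot loop and removes the duplicated `trajectories` argument.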
```python
model.save(f'mm-grpo-clevr-checkpoint-{optim_step}')
log_dict = metrics.calculate()
log_dict.update(model.calculate_metric(is_training=True))
metrics.reset()
```
The `metrics.reset()` call is inside the mini-batch loop, while `metrics.accumulate()` is called outside for the entire global batch. This leads to incorrect metric logging for all but the first mini-batch in a global step. The `metrics.reset()` at the start of the outer loop (line 174) is sufficient and this one should be removed.
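A minimal sketch of the correct accumulate/reset ordering. The `Metrics` class and loop names here are hypothetical stand-ins for the script's actual objects; the point is only where `reset()` belongs.

```python
# Hypothetical metrics accumulator: reset once per global step,
# accumulate per mini-batch, calculate after the whole global batch.
class Metrics:
    def __init__(self):
        self._values = []

    def reset(self):
        self._values = []

    def accumulate(self, value):
        self._values.append(value)

    def calculate(self):
        return {"mean_loss": sum(self._values) / len(self._values)}

metrics = Metrics()
# Two global steps, each with two mini-batch losses (illustrative numbers).
global_batches = [[1.0, 0.5], [0.5, 0.25]]
logs = []
for global_batch in global_batches:
    metrics.reset()  # reset ONCE, at the start of the global step
    for mini_batch_loss in global_batch:
        metrics.accumulate(mini_batch_loss)
        # A metrics.reset() here would discard every mini-batch's
        # contribution except the most recent one -- the bug flagged above.
    logs.append(metrics.calculate())
print(logs)  # [{'mean_loss': 0.75}, {'mean_loss': 0.375}]
```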
Resolved conflicts in `template/base.py`:
- `_check_max_length`: adopt upstream's cleaner `_truncate_feature` approach
- `_process_mm_messages`: merge our `List[Dict]` content support into upstream's refactored method structure

Made-with: Cursor
Summary
- `content: List[Dict]` support in the template pipeline
- `_check_max_length` only truncates `input_ids`/`labels`, keeping multimodal tensors intact
- Strip `None` keys from content blocks to prevent Jinja template misparsing
- `multi_modal_data`/`mm_processor_kwargs` pass-through in the vLLM sampling pipeline
- Forward `encode()` kwargs in `LazyDataset` for proper `add_generation_prompt` support
- New `CLEVRProcessor`, `MultiModalAccuracyReward`, and a multimodal GRPO demo

Changed files
- `template/base.py` — `_build_mm_messages` supports both `List[Dict]` and legacy str content; `_apply_chat_template` strips null keys; `_check_max_length` only truncates sequence fields
- `sampler/vllm_sampler/vllm_engine.py` — `sample()` accepts a `multi_modal_data` dict directly
- `sampler/vllm_sampler/vllm_sampler.py` — extract `multi_modal_data` from message content blocks for vLLM; add `mm_processor_kwargs` forwarding
- `dataset/lazy_dataset.py` — forward `encode()` kwargs to `batch_encode`
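Two of the changes above (stripping `None` keys from content blocks and extracting `multi_modal_data` from message content) can be sketched together. This is a hypothetical illustration, not the PR's code: `clean_and_extract`, the block schema, and the `"<PIL.Image>"` placeholder are all assumptions for demonstration.

```python
from typing import Any, Dict, List, Tuple

def clean_and_extract(
    messages: List[Dict[str, Any]],
) -> Tuple[List[Dict[str, Any]], Dict[str, List[Any]]]:
    """Drop None-valued keys from content blocks (which can confuse Jinja
    chat templates) and collect image entries into a vLLM-style
    multi_modal_data dict."""
    multi_modal_data: Dict[str, List[Any]] = {}
    cleaned: List[Dict[str, Any]] = []
    for msg in messages:
        content = msg["content"]
        if isinstance(content, str):  # legacy plain-string content: pass through
            cleaned.append(msg)
            continue
        blocks = []
        for block in content:
            # Strip None keys so the chat template never renders them.
            block = {k: v for k, v in block.items() if v is not None}
            if block.get("type") == "image":
                multi_modal_data.setdefault("image", []).append(block["image"])
            blocks.append(block)
        cleaned.append({**msg, "content": blocks})
    return cleaned, multi_modal_data

messages = [{
    "role": "user",
    "content": [
        {"type": "image", "image": "<PIL.Image>", "text": None},
        {"type": "text", "text": "How many red cubes are there?"},
    ],
}]
msgs, mm = clean_and_extract(messages)
print(mm)  # {'image': ['<PIL.Image>']}
```

The cleaned messages go through the chat template, while the extracted `multi_modal_data` dict is handed to the vLLM sampling path separately.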