"""Multimodal GRPO training demo with Qwen3.5 VL model on CLEVR dataset.

This script demonstrates on-policy GRPO (Group Relative Policy Optimization)
for visual question answering using:
  - Model: Qwen3.5-2B (vision-language model)
  - Dataset: AI-ModelScope/clevr_cogen_a_train (CLEVR visual reasoning)
  - Rewards: accuracy (answer correctness) + format (<think>/<answer> tags)
  - Template: Qwen3_5Template (handles vision token embedding merge)

Architecture:
  - Separate GPU groups for training model and vLLM sampler (Ray mode)
  - LoRA fine-tuning with NCCL weight sync between model and sampler
  - GRPO loss with PPO-style clipping (epsilon=0.2)

Usage:
    python mm_grpo.py

Environment variables:
    MODEL_ID        : Model path (default: ms://Qwen/Qwen3.5-2B)
    MODEL_GPUS      : GPUs for training model (default: 2)
    SAMPLER_GPUS    : GPUs for vLLM sampler (default: 1)
    NUM_GENERATIONS : Completions per prompt for GRPO grouping (default: 4)
    MAX_NEW_TOKENS  : Max generation length (default: 4096)
    LR              : Learning rate (default: 5e-5)
    MAX_STEPS       : Total optimization steps (default: 200)
    BATCH_SIZE      : Global prompt-level batch size (default: 1)
    MINI_BATCH_SIZE : Global completion-level mini-batch size (default: 4)
    MICRO_BATCH_SIZE: Per-device micro-batch size (default: 1)
    DATA_SLICE      : Number of dataset samples to use (default: 2000)
    SAVE_STEPS      : Checkpoint save interval (default: 50)
"""
import os
from typing import Any, Dict, List, Tuple

from peft import LoraConfig

import twinkle
from twinkle import DeviceGroup, DeviceMesh, get_device_placement, get_logger
from twinkle.advantage import GRPOAdvantage
from twinkle.checkpoint_engine import CheckpointEngineManager
from twinkle.data_format import SamplingParams
from twinkle.dataloader import DataLoader
from twinkle.dataset import DatasetMeta, LazyDataset
from twinkle.metric import CompletionRewardMetric
from twinkle.model import TransformersModel
from twinkle.preprocessor.mm import CLEVRProcessor
from twinkle.processor import InputProcessor
from twinkle.reward import FormatReward, MultiModalAccuracyReward
from twinkle.sampler import vLLMSampler
from twinkle.template import Qwen3_5Template

logger = get_logger()

MODEL_ID = os.environ.get('MODEL_ID', 'ms://Qwen/Qwen3.5-2B')

MODEL_GPUS = int(os.environ.get('MODEL_GPUS', 2))
SAMPLER_GPUS = int(os.environ.get('SAMPLER_GPUS', 1))
NUM_GPUS = MODEL_GPUS + SAMPLER_GPUS

NUM_GENERATIONS = int(os.environ.get('NUM_GENERATIONS', 4))
MAX_NEW_TOKENS = int(os.environ.get('MAX_NEW_TOKENS', 4096))
LEARNING_RATE = float(os.environ.get('LR', 5e-5))
MAX_STEPS = int(os.environ.get('MAX_STEPS', 200))
BATCH_SIZE = int(os.environ.get('BATCH_SIZE', 1))
MINI_BATCH_SIZE = int(os.environ.get('MINI_BATCH_SIZE', 4))
MICRO_BATCH_SIZE = int(os.environ.get('MICRO_BATCH_SIZE', 1))
GRADIENT_ACCUMULATION_STEPS = int(os.environ.get('GRADIENT_ACCUMULATION_STEPS', 1))
DATA_SLICE = int(os.environ.get('DATA_SLICE', 2000))
ADAPTER_NAME = 'default'
SAVE_STEPS = int(os.environ.get('SAVE_STEPS', 50))


def create_clevr_dataset():
    """Build the lazily-encoded CLEVR prompt dataset."""
    dataset = LazyDataset(
        DatasetMeta('ms://AI-ModelScope/clevr_cogen_a_train', split='train',
                    data_slice=range(DATA_SLICE)),
    )
    # Keep images as raw bytes during .map() to avoid PIL decode round-trips.
    dataset.cast_column('image', decode=False)
    dataset.set_template('Qwen3_5Template', model_id=MODEL_ID, max_length=4096)
    dataset.map(CLEVRProcessor(), remove_columns=['image', 'problem', 'solution'])
    dataset.encode(add_generation_prompt=True)
    return dataset


def compute_rewards(
    trajectories: List[Dict[str, Any]],
) -> Tuple[List[float], List[float], List[float]]:
    """Return (total, format, accuracy) reward lists for *trajectories*.

    Total reward is the elementwise sum of the accuracy reward (answer
    correctness) and the format reward (<think>/<answer> tag structure).
    """
    accuracy_reward_fn = MultiModalAccuracyReward()
    format_reward_fn = FormatReward()
    accuracy_rewards = accuracy_reward_fn(trajectories)
    # NOTE(review): FormatReward is invoked with the same list twice —
    # confirm its signature (completions, trajectories?) against its definition.
    format_rewards = format_reward_fn(trajectories, trajectories)
    total_rewards = [a + f for a, f in zip(accuracy_rewards, format_rewards)]
    return total_rewards, format_rewards, accuracy_rewards


def main():
    """Run on-policy GRPO training with separate model/sampler GPU groups."""
    device_groups = [
        DeviceGroup(name='model', ranks=list(range(MODEL_GPUS)), device_type='GPU'),
        DeviceGroup(
            name='sampler',
            ranks=list(range(MODEL_GPUS, NUM_GPUS)),
            device_type='GPU',
            gpus_per_worker=SAMPLER_GPUS,
        ),
    ]
    model_mesh = DeviceMesh.from_sizes(world_size=MODEL_GPUS, dp_size=MODEL_GPUS)
    sampler_mesh = DeviceMesh.from_sizes(world_size=1, dp_size=1)
    twinkle.initialize(mode='ray', nproc_per_node=NUM_GPUS, groups=device_groups, lazy_collect=False)

    lora_config = LoraConfig(
        target_modules=[
            'q_proj', 'k_proj', 'v_proj', 'o_proj',
            'gate_proj', 'up_proj', 'down_proj',
            'in_proj_qkv', 'in_proj_z', 'in_proj_a', 'in_proj_b', 'out_proj',
        ],
    )

    from modelscope import Qwen3_5ForConditionalGeneration
    model = TransformersModel(
        model_id=MODEL_ID,
        model_cls=Qwen3_5ForConditionalGeneration,
        device_mesh=model_mesh,
        remote_group='model',
    )

    model.add_adapter_to_model(ADAPTER_NAME, lora_config, gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS)
    model.set_optimizer('AdamW', lr=LEARNING_RATE)
    model.set_lr_scheduler('CosineAnnealingLR', T_max=MAX_STEPS, eta_min=0)
    model.set_loss('GRPOLoss', epsilon=0.2)
    model.set_processor(InputProcessor)
    model.set_template('Qwen3_5Template', model_id=MODEL_ID)

    sampler = vLLMSampler(
        model_id=MODEL_ID,
        engine_args={
            'gpu_memory_utilization': 0.8,
            'max_model_len': 8192,
            'max_lora_rank': 8,
            'enable_lora': True,
            'limit_mm_per_prompt': {'image': 1, 'video': 0},
            'mm_processor_cache_gb': 0,
        },
        device_mesh=sampler_mesh,
        remote_group='sampler',
    )
    sampler.set_template(Qwen3_5Template, model_id=MODEL_ID)

    ckpt_manager = CheckpointEngineManager(model=model, sampler=sampler)

    GLOBAL_BATCH_SIZE = BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS
    dataloader = DataLoader(
        dataset=create_clevr_dataset,
        batch_size=GLOBAL_BATCH_SIZE,
        min_batch_size=GLOBAL_BATCH_SIZE,
        device_mesh=model_mesh,
        remote_group='model',
    )
    advantage_fn = GRPOAdvantage()
    metrics = CompletionRewardMetric()

    sampling_params = SamplingParams(
        max_tokens=MAX_NEW_TOKENS,
        num_samples=1,
        logprobs=1,
        temperature=1.0,
    )

    optim_step = 0
    logger.info(get_device_placement())

    for batch in dataloader:
        if optim_step >= MAX_STEPS:
            break
        metrics.reset()
        global_prompts = batch if isinstance(batch, list) else [batch]

        # Push latest LoRA weights to the sampler before the on-policy rollout.
        ckpt_manager.sync_weights(merge_and_sync=False)
        sampler.reset_prefix_cache()
        # NOTE(review): list repetition interleaves prompts
        # ([p1..pn] * k -> p1..pn, p1..pn, ...); the GRPO grouping by
        # num_generations must match this layout — trivially true for the
        # default BATCH_SIZE=1, verify for larger batches.
        sample_responses = sampler.sample(
            global_prompts * NUM_GENERATIONS,
            sampling_params,
        )

        all_input_data: List[Dict[str, Any]] = []
        all_old_logps: List[List[float]] = []
        all_completion_lengths: List[int] = []
        for sample_response in sample_responses:
            for sequence in sample_response.sequences:
                all_input_data.append(sequence.new_input_feature)
                all_old_logps.append([logprob[0][1] for logprob in sequence.logprobs])
                all_completion_lengths.append(len(sequence.tokens))

        total_rewards, format_rewards, accuracy_rewards = compute_rewards(all_input_data)
        metrics.accumulate(
            completion_lengths=all_completion_lengths,
            rewards={
                'total': total_rewards,
                'format': format_rewards,
                'accuracy': accuracy_rewards,
            },
        )

        advantages = advantage_fn(total_rewards, num_generations=NUM_GENERATIONS, scale='group').tolist()

        total_completions = len(all_input_data)
        for mb_start in range(0, total_completions, MINI_BATCH_SIZE):
            mb_end = min(mb_start + MINI_BATCH_SIZE, total_completions)
            mb_inputs = all_input_data[mb_start:mb_end]
            mb_old_logps = all_old_logps[mb_start:mb_end]
            mb_advantages = advantages[mb_start:mb_end]

            model.forward_backward(
                inputs=mb_inputs,
                old_logps=mb_old_logps,
                advantages=mb_advantages,
                micro_batch_size=MICRO_BATCH_SIZE,
            )
            model.clip_grad_and_step()
            optim_step += 1

            if optim_step >= MAX_STEPS:
                break
            if optim_step % SAVE_STEPS == 0:
                model.save(f'mm-grpo-clevr-checkpoint-{optim_step}')
            log_dict = metrics.calculate()
            log_dict.update(model.calculate_metric(is_training=True))
            metrics.reset()
            logger.info(f'[Step {optim_step}/{MAX_STEPS}] {log_dict}')

    logger.info(f'Training completed. optim_steps={optim_step}')
    model.save('mm-grpo-clevr-checkpoint')


if __name__ == '__main__':
    main()
# Copyright (c) ModelScope Contributors. All rights reserved.
import re
from typing import Any, Dict, List, Optional

from twinkle.data_format import Message, Trajectory
from .base import Preprocessor


class CLEVRProcessor(Preprocessor):
    """Preprocessor for CLEVR-CoGenT visual reasoning dataset (prompt-only, for GRPO).

    Dataset fields: image (PIL.Image or dict), problem (str),
    solution (str wrapped in <answer>...</answer> tags).
    Produces prompt-only trajectories with the image in the user message and
    the ground truth stored in user_data for reward computation.

    For fast ``.map()`` performance, call ``dataset.cast_column('image', decode=False)``
    before mapping so that images stay as Arrow-native bytes dicts.
    """

    # R1-style system prompt instructing the model to emit
    # <think>...</think><answer>...</answer> structured output.
    DEFAULT_SYSTEM = ('A conversation between User and Assistant. The user asks a question, '
                      'and the Assistant solves it. The assistant first thinks about the reasoning '
                      'process in the mind and then provides the user with the answer. The reasoning '
                      'process and answer are enclosed within <think> </think> and <answer> </answer> '
                      'tags, respectively, i.e., <think> reasoning process here </think> '
                      '<answer> answer here </answer>')

    def __init__(self, system: Optional[str] = None):
        # Allow callers to override the system prompt; empty string is a valid override.
        self.system = system if system is not None else self.DEFAULT_SYSTEM

    @staticmethod
    def extract_ground_truth(solution: str) -> str:
        """Extract answer text from <answer>...</answer> tags.

        Falls back to the stripped raw solution when no tag pair is present.
        """
        match = re.search(r'<answer>\s*(.*?)\s*</answer>', solution, re.DOTALL)
        return match.group(1).strip() if match else solution.strip()

    def __call__(self, rows: Dict[str, List[Any]]) -> Dict[str, List[Any]]:
        """Map a columnar batch of raw rows to a columnar batch of trajectories."""
        rows = self.map_col_to_row(rows)
        rows = [self.preprocess(row) for row in rows]
        rows = self.map_row_to_col(rows)
        return rows

    def preprocess(self, row) -> Trajectory:
        """Convert one raw CLEVR row into a prompt-only Trajectory.

        The image rides inside the user message content; the ground truth and
        the full solution are stashed in user_data for the reward functions.
        """
        image = row['image']
        problem = row['problem']
        solution = row.get('solution', '')
        ground_truth = self.extract_ground_truth(solution)

        messages = [
            Message(role='system', content=[{
                'type': 'text',
                'text': self.system
            }]),
            Message(role='user', content=[
                {
                    'type': 'image',
                    'image': image
                },
                {
                    'type': 'text',
                    'text': problem
                },
            ]),
        ]
        return Trajectory(
            messages=messages,
            user_data=[('ground_truth', ground_truth), ('solution', solution)],
        )
+ """ + + @staticmethod + def extract_answer(text: str) -> str: + """Extract the answer from ... tags.""" + match = re.search(r'\s*(.*?)\s*', text, re.DOTALL) + return match.group(1).strip() if match else '' + + def __call__(self, trajectories: List[Dict[str, Any]], **kwargs) -> List[float]: + rewards = [] + for trajectory in trajectories: + messages = trajectory.get('messages', []) + completion = '' + for msg in reversed(messages): + if msg.get('role') == 'assistant': + content = msg.get('content', '') + if isinstance(content, str): + completion = content + elif isinstance(content, list): + completion = ' '.join(part.get('text', '') for part in content if part.get('type') == 'text') + break + + user_data = trajectory.get('user_data', []) + gt = '' + solution = '' + for item in user_data: + if item[0] == 'ground_truth': + gt = str(item[1]) + elif item[0] == 'solution': + solution = str(item[1]) + + predicted = self.extract_answer(completion) + reward = 0.0 + + # Try symbolic math verification first + try: + from math_verify import parse, verify + answer = parse(completion) + if float(verify(answer, parse(solution or gt))) > 0: + reward = 1.0 + except Exception: + pass + + # Fallback: string matching + if reward == 0.0 and predicted and gt: + if predicted.strip().lower() == gt.strip().lower(): + reward = 1.0 + else: + try: + if abs(float(predicted) - float(gt)) < 1e-5: + reward = 1.0 + except (ValueError, OverflowError): + pass + + rewards.append(reward) + return rewards diff --git a/src/twinkle/sampler/vllm_sampler/vllm_engine.py b/src/twinkle/sampler/vllm_sampler/vllm_engine.py index 911b11c3..93ce0d18 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_engine.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_engine.py @@ -183,31 +183,25 @@ async def sample(self, request_id: Optional[str] = None, priority: int = 0, *, - images: Optional[List[Any]] = None, - videos: Optional[List[Any]] = None, + multi_modal_data: Optional[Dict[str, Any]] = None, + mm_processor_kwargs: 
Optional[Dict[str, Any]] = None, **kwargs) -> SampleResponse: """ Sample completions from the model. - This is the core API aligned with tinker's sampling interface. - Args: prompt_token_ids: Input token IDs. sampling_params: Sampling parameters (tinker.types.SamplingParams or dict). - num_samples: Number of samples to generate. - logprobs: Whether to return log probabilities for generated tokens. - include_prompt_logprobs: Whether to compute logprobs on prompt tokens. - topk_prompt_logprobs: If > 0, returns top-k logprobs for each prompt token. lora_request: LoRARequest for sampling. request_id: Optional request ID for tracking. priority: Request priority (higher = more urgent). - images: Optional list of images for multimodal models. - Can be PIL.Image, file paths, URLs, or bytes. - videos: Optional list of videos for multimodal models. - Can be file paths or list of frames. + multi_modal_data: Optional dict of multimodal data for vLLM + (e.g. ``{'image': [PIL_Image, ...], 'video': [...]}``) + mm_processor_kwargs: Optional kwargs forwarded to vLLM's multimodal processor + (e.g. ``{'do_resize': False}``) Returns: - tinker.types.SampleResponse containing sequences and optionally prompt_logprobs. + SampleResponse containing sequences and optionally prompt_logprobs. 
""" from vllm.inputs import TokensPrompt @@ -222,21 +216,11 @@ async def sample(self, if request_id is None: request_id = uuid.uuid4().hex - # Build multi_modal_data if images or videos provided - multi_modal_data = {} - if images: - multi_modal_data['image'] = images - if videos: - multi_modal_data['video'] = videos - - # Build prompt (with or without multimodal data) + prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) if multi_modal_data: - prompt = TokensPrompt( - prompt_token_ids=prompt_token_ids, - multi_modal_data=multi_modal_data, - ) - else: - prompt = TokensPrompt(prompt_token_ids=prompt_token_ids) + prompt['multi_modal_data'] = multi_modal_data + if mm_processor_kwargs: + prompt['mm_processor_kwargs'] = mm_processor_kwargs if lora_request is not None and not self.enable_lora: logger.warning('lora_request provided but enable_lora is ' diff --git a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py index 915e012f..acd8df59 100644 --- a/src/twinkle/sampler/vllm_sampler/vllm_sampler.py +++ b/src/twinkle/sampler/vllm_sampler/vllm_sampler.py @@ -148,52 +148,17 @@ def encode_trajectory_for_vllm(self, trajectory: Trajectory, adapter_name: str = '', add_generation_prompt=True) -> InputFeature: - """Encode trajectory for vLLM - does not expand image tokens. + """Encode trajectory for vLLM. - Args: - trajectory: The trajectory to encode. - adapter_name: Optional LoRA adapter name. - - Returns: - InputFeature with input_ids suitable for vLLM (unexpanded image tokens). + Messages should already use transformers standard format (content is List[Dict]). + ``batch_encode`` preprocesses media refs in-place (to PIL objects). """ template = self.template if template is None: raise ValueError(f"Template not set for adapter '{adapter_name}'. 
Use set_template() first.") - # For vLLM: tokenize without passing images to the processor - # This gives us the text with placeholder tokens, which vLLM will expand - messages = [dict(msg) for msg in trajectory['messages']] - - # Preprocess images for vLLM (load as PIL Images) - # vLLM expects PIL Images, not URLs - images = [] - if trajectory.get('images'): - images = template.preprocess_images(trajectory['images']) - videos = [] - if trajectory.get('videos'): - videos = template.preprocess_videos(trajectory['videos']) - - # Apply chat template without images (to get unexpanded tokens) - # We need to convert placeholders to the model's native format - for msg in messages: - content = msg.get('content', '') - if isinstance(content, str) and template.is_mm: - # Convert placeholders to standard format for tokenization - if template.image_placeholder in content: - # Split content by image placeholder and rebuild with proper format - parts = content.split(template.image_placeholder) - new_content = [] - for i, part in enumerate(parts): - if i > 0: - # Add image token structure (vLLM will expand this) - new_content.append({'type': 'image'}) - if part.strip(): - new_content.append({'type': 'text', 'text': part}) - msg['content'] = new_content if new_content else [{'type': 'text', 'text': ''}] - encoded = template.batch_encode( - [Trajectory(messages=messages)], + [trajectory], add_generation_prompt=add_generation_prompt, )[0] @@ -205,17 +170,52 @@ def encode_trajectory_for_vllm(self, result = trajectory result.update(encoded) - - # Attach preprocessed images/videos for vLLM - if images: - result['images'] = images - if videos: - result['videos'] = videos return result def apply_patch(self, patch_cls: Union[Patch, Type[Patch], str], **kwargs) -> None: apply_patch(self, patch_cls, **kwargs) + @staticmethod + def _extract_multi_modal_data(feat: Dict[str, Any]) -> Optional[Dict[str, Any]]: + """Build vLLM ``multi_modal_data`` dict from feat. 
+ + Checks top-level 'images'/'videos' first, then falls back to + extracting PIL objects from transformers-standard message content blocks. + """ + images = feat.get('images') + videos = feat.get('videos') + + if not images and not videos: + for msg in feat.get('messages', []): + content = msg.get('content') + if not isinstance(content, list): + continue + for block in content: + if not isinstance(block, dict): + continue + btype = block.get('type') + if btype == 'image': + for key in ('image', 'url', 'path'): + if key in block and block[key] is not None: + if images is None: + images = [] + images.append(block[key]) + break + elif btype == 'video': + for key in ('video', 'url', 'path'): + if key in block and block[key] is not None: + if videos is None: + videos = [] + videos.append(block[key]) + break + + mm_data = {} + if images: + mm_data['image'] = images + if videos: + mm_data['video'] = videos + return mm_data or None + async def _sample_single( self, feat: Dict[str, Any], @@ -241,15 +241,14 @@ async def _sample_single( if hasattr(input_ids, 'tolist'): input_ids = input_ids.tolist() - images = feat.get('images') - videos = feat.get('videos') + multi_modal_data = self._extract_multi_modal_data(feat) response = await self.engine.sample( prompt_token_ids=input_ids, sampling_params=sampling_params, lora_request=lora_request, - images=images, - videos=videos, + multi_modal_data=multi_modal_data, + mm_processor_kwargs=feat.get('mm_processor_kwargs'), ) if not logprobs_only: diff --git a/src/twinkle/template/base.py b/src/twinkle/template/base.py index 167a459d..80a2b8f6 100644 --- a/src/twinkle/template/base.py +++ b/src/twinkle/template/base.py @@ -115,10 +115,7 @@ def _test_support_assistant_tokens_mask(self): def preprocess_image(self, image: ImageInput) -> 'Image.Image': if isinstance(image, dict): - if image.get('path'): - image = image['path'] - else: - image = image['bytes'] + image = image.get('bytes') or image.get('path') return load_image(image) def 
preprocess_video(self, video: VideoInput) -> List['Image.Image']: @@ -214,21 +211,40 @@ def _extract_reasoning_content(messages: list[Message]) -> List[Message]: trajectory['extend_message'] = result return [trajectory] + _truncatable_fields = {'input_ids', 'labels'} + def _check_max_length(self, input_feature: InputFeature) -> List[InputFeature]: if self.max_length and len(input_feature['input_ids']) > self.max_length: if self.truncation_strategy == 'raise': raise ValueError(f'An input message(length: {len(input_feature["input_ids"])} ' f'exceeds the maximum length({self.max_length})') elif self.truncation_strategy == 'left': - return [InputFeature(**{key: value[-self.max_length:] for key, value in input_feature.items()})] + return [ + InputFeature( + **{ + key: (value[-self.max_length:] if key in self._truncatable_fields else value) + for key, value in input_feature.items() + }) + ] elif self.truncation_strategy == 'right': - return [InputFeature(**{key: value[:self.max_length] for key, value in input_feature.items()})] + return [ + InputFeature( + **{ + key: (value[:self.max_length] if key in self._truncatable_fields else value) + for key, value in input_feature.items() + }) + ] else: # split result = [] total_length = len(input_feature['input_ids']) for start in range(0, total_length, self.max_length): end = min(start + self.max_length, total_length) - result.append(InputFeature(**{key: value[start:end] for key, value in input_feature.items()})) + result.append( + InputFeature( + **{ + key: (value[start:end] if key in self._truncatable_fields else value) + for key, value in input_feature.items() + })) return result else: return [input_feature] @@ -250,27 +266,56 @@ def _build_mm_messages(self, trajectory: Trajectory) -> List[Trajectory]: for message in messages: message = copy(message) content = message['content'] - msg_images = message.get('images') - msg_videos = message.get('videos') - msg_audios = message.get('audios') - if msg_images: - message['images'] = 
self.preprocess_images(msg_images) - assert len(message['images']) == content.count(self.image_placeholder) - if msg_videos: - message['videos'] = self.preprocess_videos(msg_videos) - assert len(message['videos']) == content.count(self.video_placeholder) - if msg_audios: - message['audios'] = self.preprocess_audios(msg_audios) - assert len(message['audios']) == content.count(self.audio_placeholder) - new_messages.append( - transfer_to_standard_message(message, self.image_placeholder, self.video_placeholder, - self.audio_placeholder, self.is_mm)) - + if isinstance(content, list): + # Transformers standard format: content is List[Dict]. + # Preprocess media references (url/path/bytes) to PIL objects in-place. + for block in content: + if not isinstance(block, dict): + continue + btype = block.get('type') + if btype == 'image': + for key in ('image', 'url', 'path'): + if key in block and block[key] is not None: + block[key] = self.preprocess_image(block[key]) + break + elif btype == 'video': + for key in ('video', 'url', 'path'): + if key in block and block[key] is not None: + block[key] = self.preprocess_video(block[key]) + break + else: + # content is str with placeholders, + # media stored in message['images']/['videos']/['audios']. 
+ msg_images = message.get('images') + msg_videos = message.get('videos') + msg_audios = message.get('audios') + if msg_images: + message['images'] = self.preprocess_images(msg_images) + assert len(message['images']) == content.count(self.image_placeholder) + if msg_videos: + message['videos'] = self.preprocess_videos(msg_videos) + assert len(message['videos']) == content.count(self.video_placeholder) + if msg_audios: + message['audios'] = self.preprocess_audios(msg_audios) + assert len(message['audios']) == content.count(self.audio_placeholder) + message = transfer_to_standard_message(message, self.image_placeholder, self.video_placeholder, + self.audio_placeholder, self.is_mm) + new_messages.append(message) trajectory['messages'] = new_messages return [trajectory] def _apply_chat_template(self, trajectory: Trajectory, add_generation_prompt: bool = False, **kwargs): messages = [dict(message) for message in trajectory['messages']] + # Arrow serialization may pad content blocks with null keys (e.g. 'image': None + # on text-only blocks). Jinja checks `'image' in item` on dict keys, so these + # phantom keys cause wrong token counts. Strip them here. 
+ for msg in messages: + if not isinstance(msg.get('content'), list): + continue + msg['content'] = [{ + k: v + for k, v in b.items() if v is not None + } for b in msg['content'] if isinstance(b, dict)] tools = [dict(tool) for tool in trajectory.get('tools', [])] inputs = self.processor.apply_chat_template( messages, From ed91e72a8eb2dcdb560cdfff70f13530f89d7d7f Mon Sep 17 00:00:00 2001 From: root Date: Mon, 30 Mar 2026 16:15:57 +0800 Subject: [PATCH 2/2] merge main --- src/twinkle/preprocessor/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/twinkle/preprocessor/__init__.py b/src/twinkle/preprocessor/__init__.py index 48ab6342..3fc975cb 100644 --- a/src/twinkle/preprocessor/__init__.py +++ b/src/twinkle/preprocessor/__init__.py @@ -3,4 +3,4 @@ from .dpo import EmojiDPOProcessor from .llm import (AlpacaProcessor, CompetitionMathGRPOProcessor, CompetitionMathProcessor, CountdownProcessor, GSM8KProcessor, SelfCognitionProcessor) -from .mm import CLEVRProcessor +from .mm import CLEVRProcessor, VisionQAProcessor