-
Notifications
You must be signed in to change notification settings - Fork 55
Add LLaMA-Factory support #38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,273 @@ | ||||||||||||||
| diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py | ||||||||||||||
| index adaaaa87..bd16f78f 100644 | ||||||||||||||
| --- a/src/llamafactory/chat/hf_engine.py | ||||||||||||||
| +++ b/src/llamafactory/chat/hf_engine.py | ||||||||||||||
| @@ -204,6 +204,9 @@ class HuggingfaceEngine(BaseEngine): | ||||||||||||||
| gen_kwargs["audio_feature_lens"] = mm_inputs["audio_feature_lens"] | ||||||||||||||
|
|
||||||||||||||
| gen_kwargs.pop("image_sizes", None) | ||||||||||||||
| + if getattr(model.config, "model_type", None) in ["bailingmm"]: | ||||||||||||||
| + gen_kwargs["input_ids"] = inputs | ||||||||||||||
| + del gen_kwargs["inputs"] | ||||||||||||||
|
|
||||||||||||||
| return gen_kwargs, prompt_length | ||||||||||||||
|
|
||||||||||||||
| diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py | ||||||||||||||
| index 3ac9c307..12c66079 100644 | ||||||||||||||
| --- a/src/llamafactory/data/mm_plugin.py | ||||||||||||||
| +++ b/src/llamafactory/data/mm_plugin.py | ||||||||||||||
| @@ -892,6 +892,187 @@ class LlavaNextVideoPlugin(BasePlugin): | ||||||||||||||
| return messages | ||||||||||||||
|
|
||||||||||||||
|
|
||||||||||||||
| +@dataclass | ||||||||||||||
| +class MingOmniPlugin(BasePlugin): | ||||||||||||||
| + def _validate_input( | ||||||||||||||
| + self, | ||||||||||||||
| + processor: Optional["MMProcessor"], | ||||||||||||||
| + images: list["ImageInput"], | ||||||||||||||
| + videos: list["VideoInput"], | ||||||||||||||
| + audios: list["AudioInput"], | ||||||||||||||
| + ) -> None: | ||||||||||||||
| + r"""Validate if this model accepts the input modalities.""" | ||||||||||||||
| + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) | ||||||||||||||
| + video_processor: BaseImageProcessor = getattr(processor, "image_processor", None) | ||||||||||||||
| + audio_processor = getattr(processor, "audio_processor", None) | ||||||||||||||
| + if len(images) != 0 and self.image_token is None: | ||||||||||||||
| + raise ValueError( | ||||||||||||||
| + "This model does not support image input. Please check whether the correct `template` is used." | ||||||||||||||
| + ) | ||||||||||||||
| + | ||||||||||||||
| + if len(videos) != 0 and self.video_token is None: | ||||||||||||||
| + raise ValueError( | ||||||||||||||
| + "This model does not support video input. Please check whether the correct `template` is used." | ||||||||||||||
| + ) | ||||||||||||||
| + | ||||||||||||||
| + if len(audios) != 0 and self.audio_token is None: | ||||||||||||||
| + raise ValueError( | ||||||||||||||
| + "This model does not support audio input. Please check whether the correct `template` is used." | ||||||||||||||
| + ) | ||||||||||||||
| + | ||||||||||||||
| + if self.image_token is not None and processor is None: | ||||||||||||||
| + raise ValueError("Processor was not found, please check and update your model file.") | ||||||||||||||
| + | ||||||||||||||
| + if self.image_token is not None and image_processor is None: | ||||||||||||||
| + raise ValueError("Image processor was not found, please check and update your model file.") | ||||||||||||||
| + | ||||||||||||||
| + if self.video_token is not None and video_processor is None: | ||||||||||||||
| + raise ValueError("Video processor was not found, please check and update your model file.") | ||||||||||||||
| + | ||||||||||||||
| + if self.audio_token is not None and audio_processor is None: | ||||||||||||||
| + raise ValueError("Audio feature extractor was not found, please check and update your model file.") | ||||||||||||||
| + | ||||||||||||||
| + @override | ||||||||||||||
| + def _get_mm_inputs( | ||||||||||||||
| + self, | ||||||||||||||
| + images: list["ImageInput"], | ||||||||||||||
| + videos: list["VideoInput"], | ||||||||||||||
| + audios: list["AudioInput"], | ||||||||||||||
| + processor: "MMProcessor", | ||||||||||||||
| + ) -> dict[str, "torch.Tensor"]: | ||||||||||||||
| + image_processor: BaseImageProcessor = getattr(processor, "image_processor", None) | ||||||||||||||
| + audio_processor = getattr(processor, "audio_processor", None) | ||||||||||||||
| + mm_inputs = {} | ||||||||||||||
| + if len(images) != 0: | ||||||||||||||
| + images = self._regularize_images( | ||||||||||||||
| + images, | ||||||||||||||
| + image_max_pixels=getattr(image_processor, "max_pixels", 2007040), | ||||||||||||||
| + image_min_pixels=getattr(image_processor, "min_pixels", 78400), | ||||||||||||||
| + )["images"] | ||||||||||||||
| + mm_inputs.update(image_processor(images=images, videos=None, return_tensors="pt")) | ||||||||||||||
| + | ||||||||||||||
| + if len(videos) != 0: | ||||||||||||||
| + videos = self._regularize_videos( | ||||||||||||||
| + videos, | ||||||||||||||
| + image_max_pixels=getattr(image_processor, "max_pixels_video", 768 * 28 * 28), | ||||||||||||||
| + image_min_pixels=getattr(image_processor, "min_pixels_video", 128 * 28 * 28), | ||||||||||||||
| + video_fps=getattr(image_processor, "video_fps", 2.0), | ||||||||||||||
| + video_maxlen=getattr(image_processor, "video_maxlen", 128), | ||||||||||||||
| + )["videos"] | ||||||||||||||
| + # Ming can only deal with even frames. | ||||||||||||||
| + videos = [video[:-1] if len(video) % 2 else video for video in videos] | ||||||||||||||
| + mm_inputs.update(image_processor(images=None, videos=videos, do_resize=True, return_tensors="pt")) | ||||||||||||||
| + | ||||||||||||||
| + # if len(audios) != 0: | ||||||||||||||
| + # sampling_rate = 16000 | ||||||||||||||
| + # audios = self._regularize_audios(audios, sampling_rate=sampling_rate)["audios"] | ||||||||||||||
| + # audios = [(torch.tensor(audio), sampling_rate)for audio in audios] | ||||||||||||||
|
||||||||||||||
| + # mm_inputs.update( | ||||||||||||||
| + # audio_processor( | ||||||||||||||
| + # audios, | ||||||||||||||
| + # padding="max_length", | ||||||||||||||
| + # use_whisper_encoder=False, | ||||||||||||||
| + # return_tensors="pt", | ||||||||||||||
| + # ) | ||||||||||||||
| + # ) | ||||||||||||||
| + | ||||||||||||||
| + return mm_inputs | ||||||||||||||
| + | ||||||||||||||
| + @override | ||||||||||||||
| + def process_messages( | ||||||||||||||
| + self, | ||||||||||||||
| + messages: list[dict[str, str]], | ||||||||||||||
| + images: list["ImageInput"], | ||||||||||||||
| + videos: list["VideoInput"], | ||||||||||||||
| + audios: list["AudioInput"], | ||||||||||||||
| + processor: Optional["MMProcessor"], | ||||||||||||||
| + ) -> list[dict[str, str]]: | ||||||||||||||
| + self._validate_input(processor, images, videos, audios) | ||||||||||||||
| + self._validate_messages(messages, images, videos, audios) | ||||||||||||||
| + messages = deepcopy(messages) | ||||||||||||||
| + image_processor: BaseImageProcessor = getattr(processor, "image_processor") | ||||||||||||||
| + image_inputs, video_inputs, audio_inputs = {}, {}, {} | ||||||||||||||
| + | ||||||||||||||
| + if len(images): | ||||||||||||||
| + image_inputs = self._get_mm_inputs(images, [], [], processor) | ||||||||||||||
| + image_grid_thw = image_inputs["image_grid_thw"] | ||||||||||||||
| + | ||||||||||||||
| + if len(videos): | ||||||||||||||
| + # assert len(videos) <= 1, "Video count must be at most 1!" | ||||||||||||||
|
||||||||||||||
| + # assert len(videos) <= 1, "Video count must be at most 1!" | |
| + if len(videos) > 1: | |
| + raise ValueError("Video count must be at most 1!") |
Copilot
AI
Sep 29, 2025
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another block of commented-out code that should be removed or properly implemented. This creates inconsistency with the audio processing logic above.
| + # if len(audios): | |
| + # audio_inputs = self._get_mm_inputs([], [], audios, processor) | |
| + # audio_feats_lengths = audio_inputs["encoder_feats_lengths"] | |
| + if len(audios): | |
| + audio_inputs = self._get_mm_inputs([], [], audios, processor) | |
| + audio_feats_lengths = audio_inputs["encoder_feats_lengths"] |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,46 @@ | ||||||
| ### model | ||||||
| model_name_or_path: inclusionAI/Ming-Lite-Omni | ||||||
| trust_remote_code: true | ||||||
|
|
||||||
| ### method | ||||||
| stage: sft | ||||||
| do_train: true | ||||||
| finetuning_type: lora | ||||||
| lora_rank: 8 | ||||||
| lora_target: query_key_value,dense | ||||||
|
|
||||||
| ### dataset | ||||||
| dataset: mllm_video_demo | ||||||
| template: ming | ||||||
| cutoff_len: 4096 | ||||||
| max_samples: 1000 | ||||||
| overwrite_cache: true | ||||||
| preprocessing_num_workers: 1 | ||||||
| dataloader_num_workers: 4 | ||||||
|
|
||||||
| ### output | ||||||
| output_dir: saves/ming-lite-omni/lora/sft | ||||||
| logging_steps: 10 | ||||||
| save_steps: 500 | ||||||
| plot_loss: true | ||||||
| overwrite_output_dir: true | ||||||
| save_only_model: false | ||||||
| report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow] | ||||||
|
||||||
| report_to: none # choices: [none, wandb, tensorboard, swanlab, mlflow] | |
| report_to: none # choices: [none, wandb, tensorboard, mlflow] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo in directory name - should be 'inclusionAI' to match the actual organization name referenced elsewhere in the documentation.