diff --git a/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml b/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml
index 262d68520..680fab43b 100755
--- a/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml
+++ b/configs/quantization/video_gen/wan_i2v/awq_w_a.yaml
@@ -46,4 +46,4 @@ quant:
         clip_sym: True
 save:
     save_lightx2v: True
-    save_path: /path/to/x2v/
\ No newline at end of file
+    save_path: /path/to/x2v/
diff --git a/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml b/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml
index 59e35dd4e..14d05479d 100755
--- a/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml
+++ b/configs/quantization/video_gen/wan_t2v/awq_w_a.yaml
@@ -46,4 +46,4 @@ quant:
         clip_sym: True
 save:
     save_lightx2v: True
-    save_path: /path/to/x2v/
\ No newline at end of file
+    save_path: /path/to/x2v/
diff --git a/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml b/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml
index 844b62214..b6a53b0e0 100644
--- a/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml
+++ b/configs/quantization/video_gen/wan_t2v/rtn_w_a.yaml
@@ -29,4 +29,4 @@ quant:
         granularity: per_token
 save:
     save_lightx2v: True
-    save_path: /path/to/x2v/
\ No newline at end of file
+    save_path: /path/to/x2v/
diff --git a/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml b/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml
index 122d31f79..7d65f31fc 100755
--- a/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml
+++ b/configs/quantization/video_gen/wan_t2v/smoothquant_w_a.yaml
@@ -42,4 +42,4 @@ quant:
         alpha: 0.7
 save:
     save_lightx2v: True
-    save_path: /path/to/x2v/
\ No newline at end of file
+    save_path: /path/to/x2v/
diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py
index 2df4f8c93..a09fc5f8d 100755
--- a/llmc/compression/quantization/base_blockwise_quantization.py
+++ b/llmc/compression/quantization/base_blockwise_quantization.py
@@ -35,11 +35,7 @@
                             _TRANSFORMERS_LN_TYPES_, EffcientFakeQuantLinear,
                             FakeQuantLinear, LlmcActFn, OriginFloatLinear,
                             RotateLinear)
-from .quant import (
-    FloatQuantizer,
-    IntegerQuantizer,
-    Weight48IntegerQuantizer,
-)
+from .quant import FloatQuantizer, IntegerQuantizer, Weight48IntegerQuantizer
 
 
 class BaseBlockwiseQuantization(BlockwiseOpt):
diff --git a/llmc/compression/quantization/module_utils.py b/llmc/compression/quantization/module_utils.py
index 1c0e6e455..80918e3f1 100755
--- a/llmc/compression/quantization/module_utils.py
+++ b/llmc/compression/quantization/module_utils.py
@@ -446,6 +446,14 @@ def __repr__(self):
         return 'LlmcQwen2RMSNorm()'
 
 
+class LlmcIndustrialCoderRMSNorm(LlmcLlamaRMSNorm):
+    def __init__(self, weight, eps=1e-6):
+        super().__init__(weight, eps)
+
+    def __repr__(self):
+        return 'LlmcIndustrialCoderRMSNorm()'
+
+
 class LlmcMixtralRMSNorm(LlmcLlamaRMSNorm):
     def __init__(self, weight, eps=1e-6):
         super().__init__(weight, eps)
@@ -1187,6 +1195,7 @@ def __repr__(self):
     'Mixtral': LlmcMixtralRMSNorm,
     'Interlm2': LlmcInternLM2RMSNorm,
     'Qwen2': LlmcQwen2RMSNorm,
+    'IndustrialCoder': LlmcIndustrialCoderRMSNorm,
     'Gemma2': LlmcGemma2RMSNorm,
     'MiniCPM': LlmcMiniCPMRMSNorm,
     'Starcoder': LlmcLayerNorm,
diff --git a/llmc/data/dataset/base_dataset.py b/llmc/data/dataset/base_dataset.py
index bd773d920..85260e83d 100755
--- a/llmc/data/dataset/base_dataset.py
+++ b/llmc/data/dataset/base_dataset.py
@@ -3,11 +3,12 @@
 from abc import ABCMeta
 
 import torch
-from datasets import load_dataset, load_from_disk
 from loguru import logger
 from PIL import Image
 from torch.nn import functional as F
 
+from datasets import load_dataset, load_from_disk
+
 from .specified_preproc import PREPROC_REGISTRY
 
 
@@ -172,9 +173,10 @@ def get_batch_process(self, samples):
         return calib_model_inputs
 
     def get_calib_dataset(self):
-        samples = self.calib_dataset[
-            int(os.environ['RANK'])::int(os.environ['WORLD_SIZE'])
-        ]
+        samples = self.calib_dataset.shard(
+            num_shards=int(os.environ['WORLD_SIZE']),
+            index=int(os.environ['RANK'])
+        )
         logger.info(f'len(samples) rank : {len(samples)}')
         calib_model_inputs = self.get_calib_model_inputs(samples)
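The `get_calib_dataset` change above replaces raw stride slicing with `datasets.Dataset.shard`, so each rank receives a proper `Dataset` object covering a disjoint, near-equal portion of the calibration set. A minimal sketch of the semantics, assuming a Hugging Face `datasets.Dataset` (the toy data and the `world_size`/`rank` values stand in for the `WORLD_SIZE`/`RANK` environment variables):

```python
from datasets import Dataset

# Toy stand-in for self.calib_dataset.
ds = Dataset.from_dict({'text': [f'sample {i}' for i in range(10)]})

world_size, rank = 4, 1  # illustrative values for the env vars
shard = ds.shard(num_shards=world_size, index=rank)

# Each rank gets a disjoint, near-equal slice; the shards together cover ds.
print(len(shard), shard['text'])
```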
diff --git a/llmc/eval/eval_base.py b/llmc/eval/eval_base.py
index 60a60589e..098c9bb8f 100755
--- a/llmc/eval/eval_base.py
+++ b/llmc/eval/eval_base.py
@@ -5,10 +5,11 @@
 import torch
 import torch.nn as nn
-from datasets import load_dataset, load_from_disk
 from human_eval.data import read_problems
 from loguru import logger
 
+from datasets import load_dataset, load_from_disk
+
 
 class BaseEval:
     def __init__(self, model, config):
diff --git a/llmc/eval/eval_ppl.py b/llmc/eval/eval_ppl.py
index d598218c5..bb41329f1 100644
--- a/llmc/eval/eval_ppl.py
+++ b/llmc/eval/eval_ppl.py
@@ -3,10 +3,11 @@
 import torch
 import torch.nn as nn
-from datasets import load_dataset, load_from_disk
 from loguru import logger
 from tqdm import tqdm
 
+from datasets import load_dataset, load_from_disk
+
 from .eval_base import BaseEval
diff --git a/llmc/models/__init__.py b/llmc/models/__init__.py
index 7351995df..4f196c6ea 100755
--- a/llmc/models/__init__.py
+++ b/llmc/models/__init__.py
@@ -5,6 +5,7 @@
 from .falcon import Falcon
 from .gemma2 import Gemma2
 from .glm4v import GLM4V
+from .industrialcoder import IndustrialCoder
 from .internlm2 import InternLM2
 from .internomni import InternOmni
 from .internvl2 import InternVL2
@@ -35,6 +36,6 @@
 from .videollava import VideoLLaVA
 from .vila import Vila
 from .vit import Vit
+from .wan2_2_t2v import Wan2T2V
 from .wan_i2v import WanI2V
 from .wan_t2v import WanT2V
-from .wan2_2_t2v import Wan2T2V
diff --git a/llmc/models/base_model.py b/llmc/models/base_model.py
index 25393a871..315a749b5 100755
--- a/llmc/models/base_model.py
+++ b/llmc/models/base_model.py
@@ -129,7 +129,7 @@ def build_tokenizer(self):
             if self.tokenizer.pad_token is None:
                 self.tokenizer.pad_token = self.tokenizer.eos_token
         else:
-            self.tokenizer = None
+            self.tokenizer = None
 
     def get_tokenizer(self):
         return self.tokenizer
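Adding the import to `llmc/models/__init__.py` is what activates the new model type: importing the module runs the `@MODEL_REGISTRY` decorator, after which a config's `model.type` can name it. A hedged sketch of the lookup, assuming the registry supports dict-style access by class name (the `cfg_model_type` variable is illustrative, not a real config field access):

```python
import llmc.models  # import side effect: @MODEL_REGISTRY registers every adapter
from llmc.utils.registry_factory import MODEL_REGISTRY

cfg_model_type = 'IndustrialCoder'  # what a config's model.type would contain
model_cls = MODEL_REGISTRY[cfg_model_type]
print(model_cls.__name__)  # IndustrialCoder
```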
+""" + +from importlib.metadata import version + +import packaging + +from llmc.utils.registry_factory import MODEL_REGISTRY + +from .base_model import BaseModel + + +@MODEL_REGISTRY +class IndustrialCoder(BaseModel): + """IndustrialCoder (IQuestCoder) standalone adapter for blockwise + quantization.""" + + def __init__(self, config, device_map=None, use_cache=False): + super().__init__(config, device_map, use_cache) + + def find_blocks(self): + # IQuestCoderForCausalLM.model -> IQuestCoderModel with .layers + self.blocks = self.model.model.layers + + def find_embed_layers(self): + base = self.model.model + self.embed_tokens = base.embed_tokens + if hasattr(base, 'rotary_emb') and ( + packaging.version.parse(version('transformers')) >= packaging.version.parse('4.45.0') + ): + self.rotary_emb = base.rotary_emb + + def find_block_name(self): + self.block_name_prefix = 'model.layers' + + def get_embed_layers(self): + return [self.embed_tokens] + + def get_attn_in_block(self, block): + return {'self_attn': block.self_attn} + + def get_attention_rotary_layers(self): + if packaging.version.parse(version('transformers')) >= packaging.version.parse('4.45.0'): + if hasattr(self, 'rotary_emb') and self.rotary_emb is not None: + return [self.rotary_emb] + return [] + return [] + + def get_head_layers(self): + return [self.model.lm_head] + + def get_pre_head_layernorm_layers(self): + return [self.model.model.norm] + + def get_layers_except_blocks(self): + if packaging.version.parse(version('transformers')) >= packaging.version.parse('4.45.0'): + rotary = [] + if hasattr(self, 'rotary_emb') and self.rotary_emb is not None: + rotary = [self.rotary_emb] + return [self.embed_tokens] + rotary + [self.model.model.norm, self.model.lm_head] + return [self.embed_tokens, self.model.model.norm, self.model.lm_head] + + def skip_layer_name(self): + return ['lm_head'] + + def has_bias(self): + # IQuestCoder config: attention_bias, mlp_bias (often False) + cfg = self.model_config + return getattr(cfg, 'attention_bias', False) or getattr(cfg, 'mlp_bias', False) + + def get_layernorms_in_block(self, block): + return { + 'input_layernorm': block.input_layernorm, + 'post_attention_layernorm': block.post_attention_layernorm, + } + + def get_subsets_in_block(self, block): + # Same layout as Qwen2 / IQuestCoderDecoderLayer + return [ + { + 'layers': { + 'self_attn.q_proj': block.self_attn.q_proj, + 'self_attn.k_proj': block.self_attn.k_proj, + 'self_attn.v_proj': block.self_attn.v_proj, + }, + 'prev_op': [block.input_layernorm], + 'input': ['self_attn.q_proj'], + 'inspect': block.self_attn, + 'has_kwargs': True, + }, + { + 'layers': {'self_attn.o_proj': block.self_attn.o_proj}, + 'prev_op': [block.self_attn.v_proj], + 'input': ['self_attn.o_proj'], + 'inspect': block.self_attn.o_proj, + 'has_kwargs': False, + }, + { + 'layers': { + 'mlp.gate_proj': block.mlp.gate_proj, + 'mlp.up_proj': block.mlp.up_proj, + }, + 'prev_op': [block.post_attention_layernorm], + 'input': ['mlp.gate_proj'], + 'inspect': block.mlp, + 'has_kwargs': False, + 'is_mlp': True, + }, + { + 'layers': {'mlp.down_proj': block.mlp.down_proj}, + 'prev_op': [block.mlp.up_proj], + 'input': ['mlp.down_proj'], + 'inspect': block.mlp.down_proj, + 'has_kwargs': False, + 'is_mlp': True, + }, + ] diff --git a/llmc/models/wan2_2_t2v.py b/llmc/models/wan2_2_t2v.py index c3db088da..a19ff0b08 100755 --- a/llmc/models/wan2_2_t2v.py +++ b/llmc/models/wan2_2_t2v.py @@ -1,5 +1,5 @@ -import gc import copy +import gc import inspect import os import shutil @@ -19,7 +19,8 @@ 
diff --git a/llmc/models/wan2_2_t2v.py b/llmc/models/wan2_2_t2v.py
index c3db088da..a19ff0b08 100755
--- a/llmc/models/wan2_2_t2v.py
+++ b/llmc/models/wan2_2_t2v.py
@@ -1,5 +1,5 @@
-import gc
-import copy
+import copy
+import gc
 import inspect
 import os
 import shutil
@@ -19,7 +19,8 @@
 
 
 class WanOfficialPipelineAdapter:
-    """Adapter that exposes Wan-Video/Wan2.2 official t2v runtime as a Pipeline-like interface."""
+    """Adapter that exposes Wan-Video/Wan2.2 official t2v runtime as a
+    Pipeline-like interface."""
 
     def __init__(
         self,
@@ -116,7 +117,8 @@ def __call__(
 
 @MODEL_REGISTRY
 class Wan2T2V(BaseModel):
-    """Wan2.2-T2V with MoE: two experts (high-noise + low-noise), same block structure as Wan2.1."""
+    """Wan2.2-T2V with MoE: two experts (high-noise + low-noise), same block
+    structure as Wan2.1."""
 
     def __init__(self, config, device_map=None, use_cache=False):
         super().__init__(config, device_map, use_cache)
@@ -200,11 +202,13 @@ def _import_impl():
                 return _import_impl()
             except Exception as e2:
                 logger.warning(
-                    f'Failed to import official Wan2.2 from wan2_repo_path={repo_path}: {e2}'
+                    'Failed to import official Wan2.2 from '
+                    f'wan2_repo_path={repo_path}: {e2}'
                 )
         logger.warning(
             'Failed to import official Wan2.2 runtime (wan package). '
-            'Diffusers fallback depends on model.allow_diffusers_fallback/model.force_diffusers. '
+            'Diffusers fallback depends on model.allow_diffusers_fallback/'
+            'model.force_diffusers. '
             f'import_error={e}'
         )
         return None, None
@@ -257,7 +261,8 @@ def _try_build_official_wan_pipeline(self):
         self.pipeline_source = 'wan_official'
         self.use_official_wan = True
         logger.info(
-            f'Loaded Wan2.2 via official Wan runtime from native checkpoint: {normalized_model_path}'
+            'Loaded Wan2.2 via official Wan runtime from native checkpoint: '
+            f'{normalized_model_path}'
         )
         return True
 
@@ -360,7 +365,10 @@ def build_model(self):
             new_block = LlmcWanTransformerBlock.new(block)
             self.Pipeline.transformer_2.blocks[block_idx] = new_block
         self.num_transformer_blocks = len(self.Pipeline.transformer.blocks)
-        self.blocks = list(self.Pipeline.transformer.blocks) + list(self.Pipeline.transformer_2.blocks)
+        self.blocks = (
+            list(self.Pipeline.transformer.blocks)
+            + list(self.Pipeline.transformer_2.blocks)
+        )
         logger.info(
             'Wan2.2 MoE: both experts wrapped (high-noise + low-noise, 80 blocks total).'
         )
@@ -456,7 +464,10 @@ def forward(self, *args, **kwargs):
                 first_block_input[self.expert_name]['kwargs'].append(
                     {k: self._to_cpu(v) for k, v in capture_kwargs.items()}
                 )
-                if all(len(first_block_input[name]['data']) >= sample_steps for name in first_block_input):
+                if all(
+                    len(first_block_input[name]['data']) >= sample_steps
+                    for name in first_block_input
+                ):
                     raise ValueError
                 return self.module(*args, **kwargs)
 
@@ -488,10 +499,13 @@ def forward(self, *args, **kwargs):
 
         self.Pipeline.transformer.blocks[0] = self.Pipeline.transformer.blocks[0].module
         if first_block_2 is not None:
-            self.Pipeline.transformer_2.blocks[0] = self.Pipeline.transformer_2.blocks[0].module
+            transformer_2 = self.Pipeline.transformer_2
+            transformer_2.blocks[0] = transformer_2.blocks[0].module
         self.Pipeline.to('cpu')
 
-        assert len(first_block_input['transformer']['data']) > 0, 'Catch transformer input data failed.'
+        assert len(first_block_input['transformer']['data']) > 0, (
+            'Catch transformer input data failed.'
+        )
         if hasattr(self.Pipeline, 'transformer_2') and self.Pipeline.transformer_2 is not None:
             assert len(first_block_input['transformer_2']['data']) > 0, \
                 'Catch transformer_2 input data failed.'
@@ -623,7 +637,8 @@ def get_layers_except_blocks(self):
 
     @staticmethod
    def copy_native_checkpoint(src, dst):
-        """Copy full Wan2.2 native checkpoint tree before overwriting expert safetensors."""
+        """Copy full Wan2.2 native checkpoint tree before overwriting expert
+        safetensors."""
         if not isinstance(src, str) or not os.path.isdir(src):
             raise RuntimeError(
                 'Wan2.2 official save expects a local native checkpoint directory, '
@@ -641,7 +656,8 @@ def copy_native_checkpoint(src, dst):
 
     @staticmethod
     def validate_native_save_structure(save_path, source_path=None):
-        """Verify saved directory has Wan2.2 native layout (experts + copied non-expert assets)."""
+        """Verify saved directory has Wan2.2 native layout (experts + copied
+        non-expert assets)."""
         if not os.path.isdir(save_path):
             raise RuntimeError(f'Wan2.2 saved path is not a directory: {save_path}')
 
@@ -705,11 +721,12 @@ def save_wan2_2_pretrained(self, path):
             self.validate_native_save_structure(path, source_path=src)
             return
 
-        # Copy the full original pipeline (VAE, text encoder, tokenizer, scheduler, etc.)
-        # so that non-quantized components are preserved.
+        # Copy the full original pipeline (VAE, text encoder, tokenizer,
+        # scheduler, etc.) so that non-quantized components are preserved.
         src = getattr(self, 'pipeline_model_path', self.model_path)
         copied_from_source = False
-        if isinstance(src, str) and os.path.isdir(src) and os.path.abspath(src) != os.path.abspath(path):
+        same_path = os.path.abspath(src) == os.path.abspath(path)
+        if isinstance(src, str) and os.path.isdir(src) and not same_path:
             if os.path.exists(path):
                 shutil.rmtree(path)
             shutil.copytree(src, path)
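The capture logic reformatted above follows a catcher pattern: block 0 of each expert is wrapped, its inputs are recorded on CPU, and a sentinel `ValueError` aborts the denoising loop once `sample_steps` inputs have been collected for every expert. A minimal sketch of the pattern, with illustrative names rather than llmc's exact code:

```python
import torch.nn as nn

class Catcher(nn.Module):
    def __init__(self, module, store, max_steps):
        super().__init__()
        self.module = module      # the real first transformer block
        self.store = store        # shared list collecting (args, kwargs)
        self.max_steps = max_steps

    def forward(self, *args, **kwargs):
        self.store.append((args, kwargs))
        if len(self.store) >= self.max_steps:
            # Sentinel exception: stop the pipeline once enough calibration
            # inputs are captured; the caller catches it and unwraps block 0.
            raise ValueError
        return self.module(*args, **kwargs)
```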
diff --git a/llmc/models/wan_t2v.py b/llmc/models/wan_t2v.py
index 59696686d..885bccda3 100755
--- a/llmc/models/wan_t2v.py
+++ b/llmc/models/wan_t2v.py
@@ -162,4 +162,4 @@ def get_layers_except_blocks(self):
         pass
 
     def skip_layer_name(self):
-        pass
\ No newline at end of file
+        pass
diff --git a/tools/download_calib_dataset.py b/tools/download_calib_dataset.py
index 37ce76bab..31fe14772 100644
--- a/tools/download_calib_dataset.py
+++ b/tools/download_calib_dataset.py
@@ -6,9 +6,10 @@
 import argparse
 import os
 
-from datasets import load_dataset
 from loguru import logger
 
+from datasets import load_dataset
+
 
 def download(calib_dataset_name, path):
     if 'pileval' in calib_dataset_name:
diff --git a/tools/download_eval_dataset.py b/tools/download_eval_dataset.py
index 7eddd8bd0..12f1f2a6a 100644
--- a/tools/download_eval_dataset.py
+++ b/tools/download_eval_dataset.py
@@ -6,9 +6,10 @@
 import argparse
 import os
 
-from datasets import load_dataset
 from loguru import logger
 
+from datasets import load_dataset
+
 
 def download(calib_dataset_name, path):
     if 'c4' in calib_dataset_name: