diff --git a/autosmoothquant/examples/ppl_eval.py b/autosmoothquant/examples/ppl_eval.py new file mode 100644 index 0000000..031e9ac --- /dev/null +++ b/autosmoothquant/examples/ppl_eval.py @@ -0,0 +1,76 @@ +import torch +import torch.nn as nn +from transformers import AutoTokenizer, AutoModelForCausalLM + +from models.phi2 import Int8PhiForCausalLM +from models.llama import Int8LlamaForCausalLM +from models.qwen2 import Int8Qwen2ForCausalLM +import tqdm +import os +from datasets import load_dataset +import argparse +from utils import get_config, get_model_architecture, build_model_and_tokenizer, parse_quant_config +from transformers.models.phi.modeling_phi import PhiForCausalLM +from transformers.models.llama.modeling_llama import LlamaForCausalLM +from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM +parser = argparse.ArgumentParser() +parser.add_argument("--alpha", type=float, default=0.5) +parser.add_argument("--model_path", type=str, default="quantized_model/qwen2/qwen2-smoothquant") + + +args = parser.parse_args() +alpha = args.alpha +model_path = args.model_path + + +class Evaluator: + def __init__(self, dataset, tokenizer, device, n_samples=40): + self.dataset = dataset + self.tokenizer = tokenizer + self.device = device + + self.dataset = tokenizer( + "\n\n".join(dataset["text"]), return_tensors="pt" + ).input_ids.to(device) + + self.n_samples = n_samples + + @torch.no_grad() + def evaluate(self, model): + model.eval() + nlls = [] + n_samples = self.n_samples if self.n_samples else self.dataset.size(1) // 2048 + for i in tqdm.tqdm(range(n_samples), desc="Evaluating..."): + batch = self.dataset[:, (i * 2048) : ((i + 1) * 2048)].to(model.device) + with torch.no_grad(): + lm_logits = model(batch).logits + shift_logits = lm_logits[:, :-1, :].contiguous().float() + shift_labels = self.dataset[:, (i * 2048) : ((i + 1) * 2048)][:, 1:] + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) + neg_log_likelihood = loss.float() * 2048 + nlls.append(neg_log_likelihood) + + return torch.exp(torch.stack(nlls).sum() / (n_samples * 2048)) + + +tokenizer = AutoTokenizer.from_pretrained(model_path) +dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") +evaluator = Evaluator(dataset, tokenizer, "cuda") +config_path = os.path.join(args.model_path, "quant_config.json") +quant_config = parse_quant_config(config_path) + +model = Int8Qwen2ForCausalLM.from_pretrained(args.model_path, quant_config, + device_map="sequential") +# model = Int8PhiForCausalLM.from_pretrained(args.model_path, quant_config, +# device_map="sequential") +# model = Int8PhiForCausalLM.from_pretrained(args.model_path, quant_config, attn_implementation="eager", +# device_map="sequential") +# model = PhiForCausalLM.from_pretrained( +# args.model_path, device_map="auto", torch_dtype=torch.float16) +# model = Qwen2ForCausalLM.from_pretrained( +# args.model_path, device_map="auto", torch_dtype=torch.float16) +ppl = evaluator.evaluate(model) +print(f"Perplexity: {ppl}") diff --git a/autosmoothquant/examples/test_model.py b/autosmoothquant/examples/test_model.py index 82996a8..8b36ad6 100644 --- a/autosmoothquant/examples/test_model.py +++ b/autosmoothquant/examples/test_model.py @@ -3,7 +3,9 @@ import argparse import json -from autosmoothquant.models import Int8LlamaForCausalLM, Int8OPTForCausalLM, Int8BaichuanForCausalLM, Int8MixtralForCausalLM +from autosmoothquant.models import (Int8LlamaForCausalLM, Int8OPTForCausalLM, + 
Int8BaichuanForCausalLM, Int8MixtralForCausalLM, + Int8PhiForCausalLM,Int8Qwen2ForCausalLM) from autosmoothquant.utils import parse_quant_config from transformers import AutoTokenizer @@ -39,6 +41,12 @@ def main(): model = Int8OPTForCausalLM.from_pretrained(args.model_path, quant_config, attn_implementation="eager", device_map="sequential") elif args.model_class == "mixtral": model = Int8MixtralForCausalLM.from_pretrained(args.model_path, quant_config, attn_implementation="eager", device_map="sequential") + elif args.model_class == "phi2": + model = Int8PhiForCausalLM.from_pretrained(args.model_path, quant_config, attn_implementation="eager", + device_map="sequential") + elif args.model_class == "qwen2": + model = Int8Qwen2ForCausalLM.from_pretrained(args.model_path, quant_config, attn_implementation="eager", + device_map="sequential") else: raise ValueError( f"Model type {args.model_class} are not supported for now.") diff --git a/autosmoothquant/models/__init__.py b/autosmoothquant/models/__init__.py index 1470eed..c8855b6 100644 --- a/autosmoothquant/models/__init__.py +++ b/autosmoothquant/models/__init__.py @@ -2,6 +2,8 @@ from .llama import Int8LlamaForCausalLM from .mixtral import Int8MixtralForCausalLM from .opt import Int8OPTForCausalLM +from .phi2 import Int8PhiForCausalLM +from .qwen2 import Int8Qwen2ForCausalLM from autosmoothquant.thirdparty.baichuan.configuration_baichuan import BaichuanConfig _MODEL_REGISTRY = { @@ -9,7 +11,9 @@ "LLaMAForCausalLM": Int8LlamaForCausalLM, "BaichuanForCausalLM": Int8BaichuanForCausalLM, "OPTForCausalLM": Int8OPTForCausalLM, - "MixtralForCausalLM": Int8MixtralForCausalLM + "MixtralForCausalLM": Int8MixtralForCausalLM, + "PhiForCausalLM": Int8PhiForCausalLM, + "Qwen2ForCausalLM": Int8Qwen2ForCausalLM } _MODEL_TYPE = { @@ -17,7 +21,9 @@ "LLaMAForCausalLM": "llama", "BaichuanForCausalLM": "baichuan", "OPTForCausalLM": "transformers", - "MixtralForCausalLM": "mixtral" + "MixtralForCausalLM": "mixtral", + "PhiForCausalLM": "phi", + "Qwen2ForCausalLM": "qwen2" } _CONFIG_REGISTRY = { diff --git a/autosmoothquant/models/phi2.py b/autosmoothquant/models/phi2.py new file mode 100644 index 0000000..4f05a02 --- /dev/null +++ b/autosmoothquant/models/phi2.py @@ -0,0 +1,255 @@ +""" PyTorch Phi model.""" + +from typing import List, Optional, Tuple, Union +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from transformers.models.phi.modeling_phi import ( + PhiMLP, + PhiAttention, + PhiDecoderLayer, + PhiPreTrainedModel, + PhiModel, + PhiForCausalLM, +) + +from transformers.activations import ACT2FN +from layers.nn.linear import W8A8BFP32OFP32LinearWithQuantScale, W8A8BFP32OFP32Linear +from transformers.utils import logging +from transformers.models.phi.configuration_phi import PhiConfig + +logger = logging.get_logger(__name__) +class Int8PhiLayerNorm(nn.LayerNorm): + @staticmethod + def from_float(module: nn.LayerNorm, output_scale: float): + assert module.normalized_shape[0] == module.weight.numel() + assert module.normalized_shape[0] == module.bias.numel() + q_module = Int8PhiLayerNorm(module.normalized_shape[0], module.eps) + q_module.weight = nn.Parameter(module.weight / output_scale) + q_module.bias = nn.Parameter(module.bias / output_scale) + return q_module +class Int8PhiMLP(nn.Module): + def __init__(self, config, quant_config: dict[str, str]): + super().__init__() + self.config = config + self.activation_fn = ACT2FN[config.hidden_act] + self.fc1_quant_type = quant_config["fc1"] + self.fc2_quant_type = 
quant_config["fc2"] + self.fc1 = W8A8BFP32OFP32Linear(config.hidden_size, config.intermediate_size, act_quant=self.fc1_quant_type) + self.fc2 = W8A8BFP32OFP32LinearWithQuantScale(config.intermediate_size, config.hidden_size,act_quant=self.fc2_quant_type) + + forward = PhiMLP.forward + + @staticmethod + @torch.no_grad() + def from_float(module: PhiMLP, + config: PhiConfig, + quant_config: dict[str, str], + fc1_input_scale: float, + fc2_input_scale: float): + int8_module = Int8PhiMLP(config, quant_config) + int8_module.fc1 = W8A8BFP32OFP32Linear.from_float( + module.fc1, fc1_input_scale, act_quant=int8_module.fc1_quant_type) + int8_module.fc2 = W8A8BFP32OFP32LinearWithQuantScale.from_float( + module.fc2, fc2_input_scale, act_quant=int8_module.fc2_quant_type) + return int8_module + +class Int8PhiAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config: PhiConfig, quant_config: dict[str, str], layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.partial_rotary_factor = config.partial_rotary_factor + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." 
+            )
+
+        self.qkv_quant_type = quant_config["qkv"]
+        self.o_quant_type = quant_config["out"]
+        self.q_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_heads * self.head_dim, use_bias=True, act_quant=self.qkv_quant_type)
+        self.k_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, use_bias=True, act_quant=self.qkv_quant_type)
+        self.v_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, use_bias=True, act_quant=self.qkv_quant_type)
+        self.dense = W8A8BFP32OFP32LinearWithQuantScale(self.num_heads * self.head_dim, self.hidden_size, use_bias=True, act_quant=self.o_quant_type)
+
+        self.qk_layernorm = config.qk_layernorm
+        # qk_layernorm is False in the released Phi-2 config, so these layers are normally skipped
+        if self.qk_layernorm:
+            self.q_layernorm = nn.LayerNorm(
+                config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
+            )
+            self.k_layernorm = nn.LayerNorm(
+                config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
+            )
+
+        self._init_rope()
+
+    _init_rope = PhiAttention._init_rope
+    forward = PhiAttention.forward
+
+    @staticmethod
+    @torch.no_grad()
+    def from_float(module: PhiAttention,
+                   config: PhiConfig,
+                   quant_config: dict[str, str],
+                   attn_input_scale: float,
+                   q_output_scale: float,
+                   k_output_scale: float,
+                   v_output_scale: float,
+                   dense_input_scale: float):
+        int8_module = Int8PhiAttention(config, quant_config)
+        # we do not implement a fused int8 attention for now because we want to use paged attention
+        int8_module.q_proj = W8A8BFP32OFP32Linear.from_float(module.q_proj, attn_input_scale, act_quant=int8_module.qkv_quant_type)
+        int8_module.k_proj = W8A8BFP32OFP32Linear.from_float(module.k_proj, attn_input_scale, act_quant=int8_module.qkv_quant_type)
+        int8_module.v_proj = W8A8BFP32OFP32Linear.from_float(module.v_proj, attn_input_scale, act_quant=int8_module.qkv_quant_type)
+        int8_module.dense = W8A8BFP32OFP32LinearWithQuantScale.from_float(
+            module.dense, dense_input_scale, act_quant=int8_module.o_quant_type)
+        return int8_module
+
+
+class Int8PhiDecoderLayer(nn.Module):
+    def __init__(self, config: PhiConfig, quant_config: dict[str, str], layer_idx: int):
+        super().__init__()
+        self.self_attn = Int8PhiAttention(config, quant_config, layer_idx=layer_idx)
+        self.mlp = Int8PhiMLP(config, quant_config)
+        self.input_layernorm = Int8PhiLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.resid_dropout = nn.Dropout(config.resid_pdrop)
+
+    forward = PhiDecoderLayer.forward
+
+    @staticmethod
+    def from_float(module: PhiDecoderLayer,
+                   config: PhiConfig,
+                   quant_config: dict[str, str],
+                   attn_input_scale: float,
+                   q_output_scale: float,
+                   k_output_scale: float,
+                   v_output_scale: float,
+                   dense_input_scale: float,
+                   fc1_input_scale: float,
+                   fc2_input_scale: float
+                   ):
+        int8_module = Int8PhiDecoderLayer(
+            config,
+            quant_config,
+            module.self_attn.layer_idx
+        )
+        int8_module.self_attn = Int8PhiAttention.from_float(
+            module.self_attn,
+            config,
+            quant_config,
+            attn_input_scale,
+            q_output_scale,
+            k_output_scale,
+            v_output_scale,
+            dense_input_scale
+        )
+        int8_module.mlp = Int8PhiMLP.from_float(
+            module.mlp,
+            config,
+            quant_config,
+            fc1_input_scale,
+            fc2_input_scale
+        )
+        if quant_config["qkv"] == "per-tensor":
+            int8_module.input_layernorm = Int8PhiLayerNorm.from_float(
+                module.input_layernorm,
+                attn_input_scale
+            )
+        else:
+            int8_module.input_layernorm = module.input_layernorm
+
+        return int8_module
+
+
+class Int8PhiModel(PhiPreTrainedModel):
+    """
+    Transformer decoder consisting of *config.num_hidden_layers* layers.
Each layer is a [`PhiDecoderLayer`] + + Args: + config: PhiConfig + """ + + def __init__(self, config: PhiConfig, quant_config: dict[str, str]): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.embed_dropout = nn.Dropout(config.embd_pdrop) + self.layers = nn.ModuleList( + [Int8PhiDecoderLayer(config, quant_config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2" + self._use_sdpa = config._attn_implementation == "sdpa" + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + get_input_embeddings = PhiModel.get_input_embeddings + set_input_embeddings = PhiModel.set_input_embeddings + forward = PhiModel.forward + + @staticmethod + def from_float(module, decoder_layer_scales, quant_config): + int8_module = Int8PhiModel(module.config, quant_config) + + int8_module.embed_tokens = module.embed_tokens + int8_module.final_layernorm = module.final_layernorm + + for i, layer in enumerate(module.layers): + int8_module.layers[i] = Int8PhiDecoderLayer.from_float( + layer, module.config, quant_config, **decoder_layer_scales[i]) + return int8_module +class Int8PhiForCausalLM(PhiPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config, quant_config): + super().__init__(config) + self.model = Int8PhiModel(config, quant_config) + self.vocab_size = config.vocab_size + # no need to quant + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True) + + # Initialize weights and apply final processing + self.post_init() + + get_input_embeddings = PhiForCausalLM.get_input_embeddings + set_input_embeddings = PhiForCausalLM.set_input_embeddings + get_output_embeddings = PhiForCausalLM.get_output_embeddings + set_output_embeddings = PhiForCausalLM.set_output_embeddings + set_decoder = PhiForCausalLM.set_decoder + get_decoder = PhiForCausalLM.get_decoder + forward = PhiForCausalLM.forward + prepare_inputs_for_generation = PhiForCausalLM.prepare_inputs_for_generation + _reorder_cache = PhiForCausalLM._reorder_cache + + @staticmethod + def from_float(module, decoder_layer_scales, quant_config): + int8_module = Int8PhiForCausalLM(module.config, quant_config) + print("start trans into int8, this might take a while") + int8_module.model = Int8PhiModel.from_float( + module.model, decoder_layer_scales, quant_config) + int8_module.lm_head = module.lm_head + return int8_module \ No newline at end of file diff --git a/autosmoothquant/models/qwen2.py b/autosmoothquant/models/qwen2.py new file mode 100644 index 0000000..35197fc --- /dev/null +++ b/autosmoothquant/models/qwen2.py @@ -0,0 +1,273 @@ +import torch +from torch import nn +from transformers.models.qwen2.modeling_qwen2 import ( + Qwen2RMSNorm, + Qwen2MLP, + Qwen2Attention, + Qwen2DecoderLayer, + Qwen2PreTrainedModel, + Qwen2Model, + Qwen2ForCausalLM, + Qwen2RotaryEmbedding +) +from transformers.models.qwen2.configuration_qwen2 import Qwen2Config +from transformers.activations import ACT2FN +from typing import Optional +import sys + +from layers.nn.linear import W8A8BFP32OFP32LinearWithQuantScale, W8A8BFP32OFP32Linear +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class 
Int8Qwen2RMSNorm(Qwen2RMSNorm):
+
+    @staticmethod
+    def from_float(module: Qwen2RMSNorm,
+                   output_scale: float):
+        int8_module = Int8Qwen2RMSNorm(module.weight.numel(), module.variance_epsilon)
+        # fold the output scale into the weight, keeping the original dtype
+        int8_module.weight = nn.Parameter((module.weight / output_scale).to(module.weight.dtype))
+        return int8_module
+
+
+class Int8Qwen2Attention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(
+        self,
+        config: Qwen2Config,
+        quant_config: dict[str, str],
+        layer_idx: Optional[int] = None
+    ):
+        super().__init__()
+        self.config = config
+        self.layer_idx = layer_idx
+        if layer_idx is None:
+            logger.warning_once(
+                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
+                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
+                "when creating this class."
+            )
+
+        self.attention_dropout = config.attention_dropout
+        self.hidden_size = config.hidden_size
+        self.num_heads = config.num_attention_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = config.num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.is_causal = True
+
+        if (self.head_dim * self.num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {self.num_heads})."
+            )
+
+        self.qkv_quant_type = quant_config["qkv"]
+        self.o_quant_type = quant_config["out"]
+        # k_proj/v_proj use num_key_value_heads so GQA checkpoints load with the correct shapes
+        self.k_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, use_bias=True,
+                                           act_quant=self.qkv_quant_type)
+        self.v_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, use_bias=True,
+                                           act_quant=self.qkv_quant_type)
+        self.q_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_heads * self.head_dim, use_bias=True,
+                                           act_quant=self.qkv_quant_type)
+        self.o_proj = W8A8BFP32OFP32LinearWithQuantScale(self.num_heads * self.head_dim, self.hidden_size,
+                                                         act_quant=self.o_quant_type)
+        self.rotary_emb = Qwen2RotaryEmbedding(
+            self.head_dim,
+            max_position_embeddings=self.max_position_embeddings,
+            base=self.rope_theta,
+        )
+
+    forward = Qwen2Attention.forward
+
+    @staticmethod
+    @torch.no_grad()
+    def from_float(module: Qwen2Attention,
+                   config: Qwen2Config,
+                   quant_config: dict[str, str],
+                   attn_input_scale: float,
+                   q_output_scale: float,
+                   k_output_scale: float,
+                   v_output_scale: float,
+                   out_input_scale: float):
+        int8_module = Int8Qwen2Attention(config, quant_config)
+        # we do not implement a fused int8 attention for now because we want to use paged attention
+        int8_module.q_proj = W8A8BFP32OFP32Linear.from_float(module.q_proj, attn_input_scale,
+                                                             act_quant=int8_module.qkv_quant_type)
+        int8_module.k_proj = W8A8BFP32OFP32Linear.from_float(module.k_proj, attn_input_scale,
+                                                             act_quant=int8_module.qkv_quant_type)
+        int8_module.v_proj = W8A8BFP32OFP32Linear.from_float(module.v_proj, attn_input_scale,
+                                                             act_quant=int8_module.qkv_quant_type)
+        int8_module.o_proj = W8A8BFP32OFP32LinearWithQuantScale.from_float(
+            module.o_proj, out_input_scale, act_quant=int8_module.o_quant_type)
+        return int8_module
+
+
+class Int8Qwen2MLP(nn.Module):
+    def __init__(self, config: Qwen2Config, quant_config: dict[str, str]):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = 
config.intermediate_size + self.gate_up_quant_type = quant_config["fc1"] + self.down_quant_type = quant_config["fc2"] + self.gate_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.intermediate_size, + act_quant=self.gate_up_quant_type) + self.up_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.intermediate_size, act_quant=self.gate_up_quant_type) + self.down_proj = W8A8BFP32OFP32LinearWithQuantScale(self.intermediate_size, self.hidden_size, + act_quant=self.down_quant_type) + self.act_fn = ACT2FN[config.hidden_act] + + forward = Qwen2MLP.forward + + @staticmethod + @torch.no_grad() + def from_float(module: Qwen2MLP, + config: Qwen2Config, + quant_config: dict[str, str], + gate_input_scale: float, + down_input_scale: float): + int8_module = Int8Qwen2MLP(config, quant_config) + int8_module.gate_proj = W8A8BFP32OFP32Linear.from_float(module.gate_proj, gate_input_scale, + act_quant=int8_module.gate_up_quant_type) + int8_module.up_proj = W8A8BFP32OFP32Linear.from_float(module.up_proj, gate_input_scale, + act_quant=int8_module.gate_up_quant_type) + int8_module.down_proj = W8A8BFP32OFP32LinearWithQuantScale.from_float( + module.down_proj, + down_input_scale, + act_quant=int8_module.down_quant_type) + return int8_module + + +class Int8Qwen2DecoderLayer(nn.Module): + def __init__(self, config: Qwen2Config, quant_config: dict[str, str], layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + # only support LlamaAttention for now. TODO: support LlamaFlashAttention2 and LlamaSdpaAttention + self.self_attn = Int8Qwen2Attention(config, quant_config, layer_idx) + self.mlp = Int8Qwen2MLP(config, quant_config) + self.input_layernorm = Int8Qwen2RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = Int8Qwen2RMSNorm(self.hidden_size, eps=config.rms_norm_eps) + + forward = Qwen2DecoderLayer.forward + + @staticmethod + def from_float(module: Qwen2DecoderLayer, + config: Qwen2Config, + quant_config: dict[str, str], + attn_input_scale: float, + q_output_scale: float, + k_output_scale: float, + v_output_scale: float, + out_input_scale: float, + gate_input_scale: float, + down_input_scale: float + ): + int8_module = Int8Qwen2DecoderLayer( + config, + quant_config, + module.self_attn.layer_idx + ) + int8_module.self_attn = Int8Qwen2Attention.from_float( + module.self_attn, + config, + quant_config, + attn_input_scale, + q_output_scale, + k_output_scale, + v_output_scale, + out_input_scale + ) + int8_module.mlp = Int8Qwen2MLP.from_float( + module.mlp, + config, + quant_config, + gate_input_scale, + down_input_scale + ) + if quant_config["qkv"] == "per-tensor": + int8_module.input_layernorm = Int8Qwen2RMSNorm.from_float( + module.input_layernorm, + attn_input_scale + ) + else: + int8_module.input_layernorm = module.input_layernorm + if quant_config["fc1"] == "per-tensor": + int8_module.post_attention_layernorm = Int8Qwen2RMSNorm.from_float( + module.post_attention_layernorm, + gate_input_scale + ) + else: + int8_module.post_attention_layernorm = module.post_attention_layernorm + return int8_module + + +class Int8Qwen2Model(Qwen2PreTrainedModel): + def __init__(self, config: Qwen2Config, quant_config: dict[str, str]): + super().__init__(config) + self.config = config + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [Int8Qwen2DecoderLayer(config, quant_config, layer_idx) for layer_idx in 
range(config.num_hidden_layers)]) + self._attn_implementation = config._attn_implementation + self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + get_input_embeddings = Qwen2Model.get_input_embeddings + set_input_embeddings = Qwen2Model.set_input_embeddings + forward = Qwen2Model.forward + + + @staticmethod + def from_float(module, decoder_layer_scales, quant_config): + int8_module = Int8Qwen2Model(module.config, quant_config) + + int8_module.embed_tokens = module.embed_tokens + int8_module.norm = module.norm + + for i, layer in enumerate(module.layers): + int8_module.layers[i] = Int8Qwen2DecoderLayer.from_float( + layer, module.config, quant_config, **decoder_layer_scales[i]) + return int8_module + + +class Int8Qwen2ForCausalLM(Qwen2PreTrainedModel): + def __init__(self, config, quant_config): + super().__init__(config) + self.config = config + self.vocab_size = config.vocab_size + self.model = Int8Qwen2Model(config, quant_config) + # no need to quant + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + # Initialize weights and apply final processing + self.post_init() + + get_input_embeddings = Qwen2ForCausalLM.get_input_embeddings + set_input_embeddings = Qwen2ForCausalLM.set_input_embeddings + get_output_embeddings = Qwen2ForCausalLM.get_output_embeddings + set_output_embeddings = Qwen2ForCausalLM.set_output_embeddings + set_decoder = Qwen2ForCausalLM.set_decoder + get_decoder = Qwen2ForCausalLM.get_decoder + forward = Qwen2ForCausalLM.forward + prepare_inputs_for_generation = Qwen2ForCausalLM.prepare_inputs_for_generation + _reorder_cache = Qwen2ForCausalLM._reorder_cache + + @staticmethod + def from_float(module, decoder_layer_scales, quant_config): + int8_module = Int8Qwen2ForCausalLM(module.config, quant_config) + print("start trans into int8, this might take a while") + int8_module.model = Int8Qwen2Model.from_float( + module.model, decoder_layer_scales, quant_config) + int8_module.lm_head = module.lm_head + return int8_module diff --git a/autosmoothquant/quantize/calibration.py b/autosmoothquant/quantize/calibration.py index dccedf2..467efca 100644 --- a/autosmoothquant/quantize/calibration.py +++ b/autosmoothquant/quantize/calibration.py @@ -8,7 +8,7 @@ from functools import partial import numpy as np from tqdm import tqdm -from autosmoothquant.models import _MODEL_TYPE +from models import _MODEL_TYPE def _model_preprocess(model): @@ -19,12 +19,13 @@ def _model_preprocess(model): original_top_k = model.model.layers[0].block_sparse_moe.top_k num_local_experts = getattr(model.config, "num_local_experts") info_dict["original_top_k"] = original_top_k - #FIXME: To get all expert act scales, we set top_k to the number of total experts + # FIXME: To get all expert act scales, we set top_k to the number of total experts # which might have negative effects on generating sclaes for layer in model.model.layers: layer.block_sparse_moe.top_k = num_local_experts return info_dict + def _model_postprocess(model, info_dict): if info_dict["model_type"] == "mixtral": original_top_k = info_dict["original_top_k"] @@ -32,13 +33,14 @@ def _model_postprocess(model, info_dict): for layer in model.model.layers: layer.block_sparse_moe.top_k = original_top_k + def get_act_scales(model, tokenizer, dataset_path, num_samples=512, seq_len=512): model.eval() info_dict = _model_preprocess(model) device = next(model.parameters()).device # Only support 
pretraining_tp=1 when capturing activation for now if hasattr(model.config, "pretraining_tp"): - model.config.pretraining_tp = 1 + model.config.pretraining_tp = 1 act_scales = {} def stat_tensor(name, tensor): @@ -73,30 +75,57 @@ def stat_input_hook(m, x, y, name): for h in hooks: h.remove() - + _model_postprocess(model, info_dict) return act_scales + @torch.no_grad() def collect_transformers_layer_scales(model, act_dict): decoder_layer_scales = [] for idx in range(model.config.num_hidden_layers): scale_dict = {} scale_dict["attn_input_scale"] = act_dict[ - f"model.decoder.layers.{idx}.self_attn.q_proj"]['input'] / 127 + f"model.decoder.layers.{idx}.self_attn.q_proj"]['input'] / 127 scale_dict["q_output_scale"] = act_dict[ - f"model.decoder.layers.{idx}.self_attn.q_proj"]['output'] / 127 + f"model.decoder.layers.{idx}.self_attn.q_proj"]['output'] / 127 scale_dict["k_output_scale"] = act_dict[ - f"model.decoder.layers.{idx}.self_attn.k_proj"]['output'] / 127 + f"model.decoder.layers.{idx}.self_attn.k_proj"]['output'] / 127 scale_dict["v_output_scale"] = act_dict[ - f"model.decoder.layers.{idx}.self_attn.v_proj"]['output'] / 127 + f"model.decoder.layers.{idx}.self_attn.v_proj"]['output'] / 127 scale_dict["out_input_scale"] = act_dict[ - f"model.decoder.layers.{idx}.self_attn.out_proj"]['input'] / 127 + f"model.decoder.layers.{idx}.self_attn.out_proj"]['input'] / 127 + scale_dict["fc1_input_scale"] = act_dict[ + f"model.decoder.layers.{idx}.fc1"]['input'] / 127 + scale_dict["fc2_input_scale"] = act_dict[ + f"model.decoder.layers.{idx}.fc2"]["input"] / 127 + decoder_layer_scales.append(scale_dict) + + return decoder_layer_scales + + +@torch.no_grad() +def collect_phi2_layer_scales(model, act_dict): + decoder_layer_scales = [] + for idx in range(model.config.num_hidden_layers): + scale_dict = {} + # self attenion scales + scale_dict["attn_input_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.q_proj"]['input'] / 127 + scale_dict["q_output_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.q_proj"]['output'] / 127 + scale_dict["k_output_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.k_proj"]['output'] / 127 + scale_dict["v_output_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.v_proj"]['output'] / 127 + scale_dict["dense_input_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.dense"]['input'] / 127 + # mlp scales scale_dict["fc1_input_scale"] = act_dict[ - f"model.decoder.layers.{idx}.fc1"]['input'] / 127 + f"model.layers.{idx}.mlp.fc1"]['input'] / 127 scale_dict["fc2_input_scale"] = act_dict[ - f"model.decoder.layers.{idx}.fc2"]["input"] / 127 + f"model.layers.{idx}.mlp.fc2"]["input"] / 127 decoder_layer_scales.append(scale_dict) return decoder_layer_scales @@ -108,62 +137,64 @@ def collect_llama_layer_scales(model, act_dict): for idx in range(model.config.num_hidden_layers): scale_dict = {} scale_dict["attn_input_scale"] = act_dict[ - f"model.layers.{idx}.self_attn.q_proj"]['input'] / 127 + f"model.layers.{idx}.self_attn.q_proj"]['input'] / 127 scale_dict["q_output_scale"] = act_dict[ - f"model.layers.{idx}.self_attn.q_proj"]['output'] / 127 + f"model.layers.{idx}.self_attn.q_proj"]['output'] / 127 scale_dict["k_output_scale"] = act_dict[ - f"model.layers.{idx}.self_attn.k_proj"]['output'] / 127 + f"model.layers.{idx}.self_attn.k_proj"]['output'] / 127 scale_dict["v_output_scale"] = act_dict[ - f"model.layers.{idx}.self_attn.v_proj"]['output'] / 127 + f"model.layers.{idx}.self_attn.v_proj"]['output'] / 127 scale_dict["out_input_scale"] = act_dict[ - 
f"model.layers.{idx}.self_attn.o_proj"]['input'] / 127 + f"model.layers.{idx}.self_attn.o_proj"]['input'] / 127 # mlp scales scale_dict["gate_input_scale"] = act_dict[ - f"model.layers.{idx}.mlp.gate_proj"]['input'] / 127 + f"model.layers.{idx}.mlp.gate_proj"]['input'] / 127 scale_dict["down_input_scale"] = act_dict[ - f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 + f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 decoder_layer_scales.append(scale_dict) return decoder_layer_scales + @torch.no_grad() def collect_baichuan_layer_scales(model, act_dict): decoder_layer_scales = [] for idx in range(model.config.num_hidden_layers): scale_dict = {} scale_dict["attn_input_scale"] = act_dict[ - f"model.layers.{idx}.self_attn.W_pack"]['input'] / 127 + f"model.layers.{idx}.self_attn.W_pack"]['input'] / 127 scale_dict["attn_output_scale"] = act_dict[ - f"model.layers.{idx}.self_attn.W_pack"]['output'] / 127 + f"model.layers.{idx}.self_attn.W_pack"]['output'] / 127 scale_dict["out_input_scale"] = act_dict[ - f"model.layers.{idx}.self_attn.o_proj"]['input'] / 127 + f"model.layers.{idx}.self_attn.o_proj"]['input'] / 127 # mlp scales scale_dict["gate_input_scale"] = act_dict[ - f"model.layers.{idx}.mlp.gate_proj"]['input'] / 127 + f"model.layers.{idx}.mlp.gate_proj"]['input'] / 127 scale_dict["down_input_scale"] = act_dict[ - f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 + f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 decoder_layer_scales.append(scale_dict) return decoder_layer_scales + @torch.no_grad() def collect_mixtral_layer_scales(model, act_dict): decoder_layer_scales = [] for idx in range(model.config.num_hidden_layers): scale_dict = {} scale_dict["attn_input_scale"] = act_dict[ - f"model.layers.{idx}.self_attn.q_proj"]['input'] / 127 + f"model.layers.{idx}.self_attn.q_proj"]['input'] / 127 scale_dict["q_output_scale"] = act_dict[ - f"model.layers.{idx}.self_attn.q_proj"]['output'] / 127 + f"model.layers.{idx}.self_attn.q_proj"]['output'] / 127 scale_dict["k_output_scale"] = act_dict[ - f"model.layers.{idx}.self_attn.k_proj"]['output'] / 127 + f"model.layers.{idx}.self_attn.k_proj"]['output'] / 127 scale_dict["v_output_scale"] = act_dict[ - f"model.layers.{idx}.self_attn.v_proj"]['output'] / 127 + f"model.layers.{idx}.self_attn.v_proj"]['output'] / 127 scale_dict["out_input_scale"] = act_dict[ - f"model.layers.{idx}.self_attn.o_proj"]['input'] / 127 + f"model.layers.{idx}.self_attn.o_proj"]['input'] / 127 # moe scales scale_dict["moe_input_scale"] = act_dict[ - f"model.layers.{idx}.block_sparse_moe.gate"]['input'] / 127 + f"model.layers.{idx}.block_sparse_moe.gate"]['input'] / 127 down_input_scales = [] num_local_experts = getattr(model.config, "num_local_experts") for i in range(num_local_experts): @@ -173,13 +204,39 @@ def collect_mixtral_layer_scales(model, act_dict): return decoder_layer_scales + +@torch.no_grad() +def collect_qwen2_layer_scales(model, act_dict): + decoder_layer_scales = [] + for idx in range(model.config.num_hidden_layers): + scale_dict = {} + scale_dict["attn_input_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.q_proj"]['input'] / 127 + scale_dict["q_output_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.q_proj"]['output'] / 127 + scale_dict["k_output_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.k_proj"]['output'] / 127 + scale_dict["v_output_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.v_proj"]['output'] / 127 + scale_dict["out_input_scale"] = act_dict[ + f"model.layers.{idx}.self_attn.o_proj"]['input'] / 127 + # mlp scales + 
scale_dict["gate_input_scale"] = act_dict[ + f"model.layers.{idx}.mlp.gate_proj"]['input'] / 127 + scale_dict["down_input_scale"] = act_dict[ + f"model.layers.{idx}.mlp.down_proj"]["input"] / 127 + decoder_layer_scales.append(scale_dict) + + return decoder_layer_scales + + @torch.no_grad() def get_static_decoder_layer_scales(model, tokenizer, dataset_path, num_samples=512, seq_len=512, - model_type = "transformers" + model_type="transformers" ): model.eval() device = next(model.parameters()).device @@ -229,6 +286,10 @@ def stat_io_hook(m, x, y, name): decoder_layer_scales = collect_baichuan_layer_scales(model, act_dict) elif model_type == "mixtral": decoder_layer_scales = collect_mixtral_layer_scales(model, act_dict) + elif model_type == "phi2": + decoder_layer_scales = collect_phi2_layer_scales(model, act_dict) + elif model_type == "qwen2": + decoder_layer_scales = collect_qwen2_layer_scales(model, act_dict) else: raise ValueError(f"unsupport model type: {model_type}") diff --git a/autosmoothquant/quantize/smooth.py b/autosmoothquant/quantize/smooth.py index ffa8901..56fb80e 100644 --- a/autosmoothquant/quantize/smooth.py +++ b/autosmoothquant/quantize/smooth.py @@ -4,7 +4,9 @@ from transformers.models.opt.modeling_opt import OPTDecoderLayer from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer, MixtralRMSNorm -from autosmoothquant.thirdparty.baichuan.modeling_baichuan import RMSNorm, BaichuanLayer +from transformers.models.phi.modeling_phi import PhiDecoderLayer +from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer, Qwen2RMSNorm +from thirdparty.baichuan.modeling_baichuan import RMSNorm, BaichuanLayer @torch.no_grad() @@ -20,6 +22,8 @@ def smooth_ln_fcs(ln, fcs, act_scales, model_type = "transformers", alpha=0.5): assert isinstance(ln, RMSNorm) elif model_type == "mixtral": assert isinstance(ln, MixtralRMSNorm) + elif model_type == "qwen2": + assert isinstance(ln, Qwen2RMSNorm) else: assert isinstance(ln, nn.LayerNorm) @@ -33,7 +37,8 @@ def smooth_ln_fcs(ln, fcs, act_scales, model_type = "transformers", alpha=0.5): ).clamp(min=1e-5).to(device).to(dtype) ln.weight.div_(scales) - if model_type == "transformers": + # ln.bias.div_(scales) + if model_type == "phi2": ln.bias.div_(scales) for fc in fcs: @@ -91,5 +96,24 @@ def smooth_lm(model, scales, alpha=0.5): fcs.append(expert.w3) fcs_input_scales = scales[name + '.block_sparse_moe.gate'] smooth_ln_fcs(ffn_ln, fcs, fcs_input_scales, "mixtral", alpha) + elif isinstance(module, PhiDecoderLayer): + print(f"smooth phi model: {name}") + attn_ln = module.input_layernorm + fc1 = module.mlp.fc1 + # fc2 = module.mlp.fc2 + qkvfc1 = [module.self_attn.q_proj, + module.self_attn.k_proj, module.self_attn.v_proj, fc1] + qkv_input_scales = scales[name + '.self_attn.q_proj'] + smooth_ln_fcs(attn_ln, qkvfc1, qkv_input_scales, "phi2", alpha) + elif isinstance(module, Qwen2DecoderLayer): + print(f"smooth qwen model: {name}") + attn_ln = module.input_layernorm #attention forward norm + qkv = [module.self_attn.q_proj, + module.self_attn.k_proj, module.self_attn.v_proj] + qkv_input_scales = scales[name + '.self_attn.q_proj'] + smooth_ln_fcs(attn_ln, qkv, qkv_input_scales, "qwen2", alpha) - + ffn_ln = module.post_attention_layernorm #feed forward norm + fcs = [module.mlp.gate_proj, module.mlp.up_proj] + fcs_input_scales = scales[name + '.mlp.gate_proj'] + smooth_ln_fcs(ffn_ln, fcs, fcs_input_scales, "qwen2", alpha) diff --git 
a/autosmoothquant/utils/utils.py b/autosmoothquant/utils/utils.py
index 997f6f9..dcdadc8 100644
--- a/autosmoothquant/utils/utils.py
+++ b/autosmoothquant/utils/utils.py
@@ -39,7 +39,16 @@ def parse_quant_config(config_path):
 def build_model_and_tokenizer(model_name, trust_remote_code: bool = True):
     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code, model_max_length=512)
     kwargs = {"torch_dtype": torch.float16, "device_map": "sequential"}
-    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=trust_remote_code, **kwargs)
+    # AutoModelForCausalLM (already imported here) resolves to PhiForCausalLM / Qwen2ForCausalLM
+    # from the checkpoint config, so no extra model-class imports are needed in this module
+    if model_name == "phi2":
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name, device_map="auto", torch_dtype=torch.float16)
+    elif model_name == "qwen2":
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name, device_map="auto", torch_dtype=torch.float16)
+    else:
+        model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=trust_remote_code, **kwargs)
     return model, tokenizer
 
 def get_model_architecture(config) -> Type[nn.Module]:
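
Usage note (illustrative, not part of the patch): the snippet below is a minimal sketch of how the new Qwen2 path is expected to be exercised once a SmoothQuant checkpoint has been exported. It only reuses calls that appear in this diff (parse_quant_config, Int8Qwen2ForCausalLM.from_pretrained with attn_implementation="eager" and device_map="sequential", and the generation hooks wired up in qwen2.py). The checkpoint path is the default from ppl_eval.py and is an assumption, as are the prompt and the availability of a CUDA device.

    # Hypothetical smoke test for a quantized Qwen2 checkpoint (path is an assumption).
    import os
    import torch
    from transformers import AutoTokenizer
    from autosmoothquant.models import Int8Qwen2ForCausalLM
    from autosmoothquant.utils import parse_quant_config

    model_path = "quantized_model/qwen2/qwen2-smoothquant"
    quant_config = parse_quant_config(os.path.join(model_path, "quant_config.json"))

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = Int8Qwen2ForCausalLM.from_pretrained(
        model_path, quant_config, attn_implementation="eager", device_map="sequential")

    # generation goes through the prepare_inputs_for_generation hook copied from Qwen2ForCausalLM
    inputs = tokenizer("The capital of France is", return_tensors="pt").to(model.device)
    with torch.no_grad():
        output_ids = model.generate(**inputs, max_new_tokens=16)
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

Perplexity can then be measured with the new example script, e.g. `python ppl_eval.py --model_path quantized_model/qwen2/qwen2-smoothquant`, noting that ppl_eval.py imports `models.*` and `utils` directly rather than through the `autosmoothquant` package, so the autosmoothquant source directory is assumed to be on PYTHONPATH.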