76 changes: 76 additions & 0 deletions autosmoothquant/examples/ppl_eval.py
@@ -0,0 +1,76 @@
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForCausalLM

from autosmoothquant.models import Int8PhiForCausalLM, Int8LlamaForCausalLM, Int8Qwen2ForCausalLM
import tqdm
import os
from datasets import load_dataset
import argparse
from autosmoothquant.utils import parse_quant_config
from transformers.models.phi.modeling_phi import PhiForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM
parser = argparse.ArgumentParser()
parser.add_argument("--alpha", type=float, default=0.5)
parser.add_argument("--model_path", type=str, default="quantized_model/qwen2/qwen2-smoothquant")


args = parser.parse_args()
alpha = args.alpha
model_path = args.model_path


class Evaluator:
def __init__(self, dataset, tokenizer, device, n_samples=40):
self.dataset = dataset
self.tokenizer = tokenizer
self.device = device

self.dataset = tokenizer(
"\n\n".join(dataset["text"]), return_tensors="pt"
).input_ids.to(device)

self.n_samples = n_samples

    @torch.no_grad()
    def evaluate(self, model):
        model.eval()
        nlls = []
        seq_len = 2048
        n_samples = self.n_samples if self.n_samples else self.dataset.size(1) // seq_len
        for i in tqdm.tqdm(range(n_samples), desc="Evaluating..."):
            batch = self.dataset[:, (i * seq_len) : ((i + 1) * seq_len)].to(model.device)
            lm_logits = model(batch).logits
            # shift so that tokens < n predict token n
            shift_logits = lm_logits[:, :-1, :].contiguous().float()
            shift_labels = batch[:, 1:]
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            # scale the mean token loss back up to a per-window NLL
            neg_log_likelihood = loss.float() * seq_len
            nlls.append(neg_log_likelihood)

        return torch.exp(torch.stack(nlls).sum() / (n_samples * seq_len))


tokenizer = AutoTokenizer.from_pretrained(model_path)
dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test")
evaluator = Evaluator(dataset, tokenizer, "cuda")
config_path = os.path.join(args.model_path, "quant_config.json")
quant_config = parse_quant_config(config_path)

model = Int8Qwen2ForCausalLM.from_pretrained(args.model_path, quant_config,
device_map="sequential")
# model = Int8PhiForCausalLM.from_pretrained(args.model_path, quant_config,
Owner
could you please add an arg to choose the model arch?
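A reply-style sketch of one way to do this (hedged: the `--model-class` flag and dispatch table are illustrative, not part of this PR; the Int8 classes are the ones the script already imports):

```python
# Hypothetical sketch: dispatch on a --model-class flag instead of
# hardcoding Int8Qwen2ForCausalLM below. Flag name and mapping are
# assumptions; the classes come from this PR's imports.
parser.add_argument("--model-class", type=str, default="qwen2",
                    choices=["llama", "phi2", "qwen2"])

_INT8_CLASSES = {
    "llama": Int8LlamaForCausalLM,
    "phi2": Int8PhiForCausalLM,
    "qwen2": Int8Qwen2ForCausalLM,
}

def load_int8_model(path, quant_config, model_class):
    # attn_implementation="eager" mirrors the phi2/qwen2 calls in test_model.py
    return _INT8_CLASSES[model_class].from_pretrained(
        path, quant_config, attn_implementation="eager", device_map="sequential")

# model = load_int8_model(args.model_path, quant_config, args.model_class)
```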

# device_map="sequential")
# model = Int8PhiForCausalLM.from_pretrained(args.model_path, quant_config, attn_implementation="eager",
# device_map="sequential")
# model = PhiForCausalLM.from_pretrained(
# args.model_path, device_map="auto", torch_dtype=torch.float16)
# model = Qwen2ForCausalLM.from_pretrained(
# args.model_path, device_map="auto", torch_dtype=torch.float16)
ppl = evaluator.evaluate(model)
print(f"Perplexity: {ppl}")
10 changes: 9 additions & 1 deletion autosmoothquant/examples/test_model.py
@@ -3,7 +3,9 @@
import argparse
import json

from autosmoothquant.models import Int8LlamaForCausalLM, Int8OPTForCausalLM, Int8BaichuanForCausalLM, Int8MixtralForCausalLM
from autosmoothquant.models import (Int8LlamaForCausalLM, Int8OPTForCausalLM,
Int8BaichuanForCausalLM, Int8MixtralForCausalLM,
                                    Int8PhiForCausalLM, Int8Qwen2ForCausalLM)
from autosmoothquant.utils import parse_quant_config
from transformers import AutoTokenizer

@@ -39,6 +41,12 @@ def main():
model = Int8OPTForCausalLM.from_pretrained(args.model_path, quant_config, attn_implementation="eager", device_map="sequential")
elif args.model_class == "mixtral":
model = Int8MixtralForCausalLM.from_pretrained(args.model_path, quant_config, attn_implementation="eager", device_map="sequential")
elif args.model_class == "phi2":
model = Int8PhiForCausalLM.from_pretrained(args.model_path, quant_config, attn_implementation="eager",
device_map="sequential")
elif args.model_class == "qwen2":
model = Int8Qwen2ForCausalLM.from_pretrained(args.model_path, quant_config, attn_implementation="eager",
device_map="sequential")
else:
raise ValueError(
f"Model type {args.model_class} are not supported for now.")
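With these branches in place, the new backends are reachable from the existing CLI; assuming the flag spellings match the attribute names used above, a quantized Phi-2 checkpoint would be tested with something like `python examples/test_model.py --model_class phi2 --model_path <path-to-quantized-model>` (the path is a placeholder).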
10 changes: 8 additions & 2 deletions autosmoothquant/models/__init__.py
@@ -2,22 +2,28 @@
from .llama import Int8LlamaForCausalLM
from .mixtral import Int8MixtralForCausalLM
from .opt import Int8OPTForCausalLM
from .phi2 import Int8PhiForCausalLM
from .qwen2 import Int8Qwen2ForCausalLM
from autosmoothquant.thirdparty.baichuan.configuration_baichuan import BaichuanConfig

_MODEL_REGISTRY = {
"LlamaForCausalLM": Int8LlamaForCausalLM,
"LLaMAForCausalLM": Int8LlamaForCausalLM,
"BaichuanForCausalLM": Int8BaichuanForCausalLM,
"OPTForCausalLM": Int8OPTForCausalLM,
"MixtralForCausalLM": Int8MixtralForCausalLM
"MixtralForCausalLM": Int8MixtralForCausalLM,
"PhiForCausalLM": Int8PhiForCausalLM,
"Qwen2ForCausalLM": Int8Qwen2ForCausalLM
}

_MODEL_TYPE = {
"LlamaForCausalLM": "llama",
"LLaMAForCausalLM": "llama",
"BaichuanForCausalLM": "baichuan",
"OPTForCausalLM": "transformers",
"MixtralForCausalLM": "mixtral"
"MixtralForCausalLM": "mixtral",
"PhiForCausalLM": "phi",
"Qwen2ForCausalLM": "qwen2"
}

_CONFIG_REGISTRY = {
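For context, a minimal sketch of how registries like these are typically consumed (illustrative only; the repo's real resolution helpers, e.g. `get_model_architecture` in `autosmoothquant.utils`, are the authoritative path):

```python
# Illustrative only: map an HF config's architecture string to its Int8
# wrapper class and model-type tag via the registries above. The helper
# name is an assumption, not this repo's API.
from transformers import AutoConfig

def resolve_int8_class(model_path: str):
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    for arch in getattr(config, "architectures", None) or []:
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch], _MODEL_TYPE[arch]
    raise ValueError(f"No Int8 wrapper registered for {config.architectures}")
```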
255 changes: 255 additions & 0 deletions autosmoothquant/models/phi2.py
@@ -0,0 +1,255 @@
""" PyTorch Phi model."""

from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from transformers.models.phi.modeling_phi import (
PhiMLP,
PhiAttention,
PhiDecoderLayer,
PhiPreTrainedModel,
PhiModel,
PhiForCausalLM,
)

from transformers.activations import ACT2FN
from autosmoothquant.layers.nn.linear import W8A8BFP32OFP32LinearWithQuantScale, W8A8BFP32OFP32Linear
from transformers.utils import logging
from transformers.models.phi.configuration_phi import PhiConfig

logger = logging.get_logger(__name__)


class Int8PhiLayerNorm(nn.LayerNorm):
@staticmethod
def from_float(module: nn.LayerNorm, output_scale: float):
assert module.normalized_shape[0] == module.weight.numel()
assert module.normalized_shape[0] == module.bias.numel()
q_module = Int8PhiLayerNorm(module.normalized_shape[0], module.eps)
q_module.weight = nn.Parameter(module.weight / output_scale)
q_module.bias = nn.Parameter(module.bias / output_scale)
return q_module


class Int8PhiMLP(nn.Module):
def __init__(self, config, quant_config: dict[str, str]):
super().__init__()
self.config = config
self.activation_fn = ACT2FN[config.hidden_act]
self.fc1_quant_type = quant_config["fc1"]
self.fc2_quant_type = quant_config["fc2"]
self.fc1 = W8A8BFP32OFP32Linear(config.hidden_size, config.intermediate_size, act_quant=self.fc1_quant_type)
        self.fc2 = W8A8BFP32OFP32LinearWithQuantScale(config.intermediate_size, config.hidden_size, act_quant=self.fc2_quant_type)

forward = PhiMLP.forward

@staticmethod
@torch.no_grad()
def from_float(module: PhiMLP,
config: PhiConfig,
quant_config: dict[str, str],
fc1_input_scale: float,
fc2_input_scale: float):
int8_module = Int8PhiMLP(config, quant_config)
int8_module.fc1 = W8A8BFP32OFP32Linear.from_float(
module.fc1, fc1_input_scale, act_quant=int8_module.fc1_quant_type)
int8_module.fc2 = W8A8BFP32OFP32LinearWithQuantScale.from_float(
module.fc2, fc2_input_scale, act_quant=int8_module.fc2_quant_type)
return int8_module

class Int8PhiAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config: PhiConfig, quant_config: dict[str, str], layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
if layer_idx is None:
logger.warning_once(
f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
"lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
"when creating this class."
)

self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.partial_rotary_factor = config.partial_rotary_factor
self.is_causal = True

if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)

self.qkv_quant_type = quant_config["qkv"]
self.o_quant_type = quant_config["out"]
self.q_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_heads * self.head_dim, use_bias=True, act_quant=self.qkv_quant_type)
self.k_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, use_bias=True, act_quant=self.qkv_quant_type)
self.v_proj = W8A8BFP32OFP32Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, use_bias=True, act_quant=self.qkv_quant_type)
self.dense = W8A8BFP32OFP32LinearWithQuantScale(self.num_heads * self.head_dim, self.hidden_size, use_bias=True, act_quant=self.o_quant_type)

self.qk_layernorm = config.qk_layernorm
        # qk_layernorm is False in the default Phi-2 config, so this branch is normally skipped
if self.qk_layernorm:
self.q_layernorm = nn.LayerNorm(
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
)
self.k_layernorm = nn.LayerNorm(
config.hidden_size // self.num_heads, eps=config.layer_norm_eps, elementwise_affine=True
)

self._init_rope()

_init_rope = PhiAttention._init_rope
forward = PhiAttention.forward

@staticmethod
@torch.no_grad()
def from_float(module: PhiAttention,
config: PhiConfig,
quant_config: dict[str, str],
attn_input_scale: float,
q_output_scale: float,
k_output_scale: float,
v_output_scale: float,
dense_input_scale: float):
int8_module = Int8PhiAttention(config, quant_config)
        # we do not reimplement the attention forward for now because we want to use paged attention
int8_module.q_proj = W8A8BFP32OFP32Linear.from_float(module.q_proj, attn_input_scale, act_quant=int8_module.qkv_quant_type)
int8_module.k_proj = W8A8BFP32OFP32Linear.from_float(module.k_proj, attn_input_scale, act_quant=int8_module.qkv_quant_type)
int8_module.v_proj = W8A8BFP32OFP32Linear.from_float(module.v_proj, attn_input_scale, act_quant=int8_module.qkv_quant_type)
int8_module.dense = W8A8BFP32OFP32LinearWithQuantScale.from_float(
module.dense, dense_input_scale, act_quant=int8_module.o_quant_type)
return int8_module


class Int8PhiDecoderLayer(nn.Module):
def __init__(self, config: PhiConfig, quant_config: dict[str, str], layer_idx: int):
super().__init__()
self.self_attn = Int8PhiAttention(config, quant_config, layer_idx=layer_idx)
self.mlp = Int8PhiMLP(config, quant_config)
self.input_layernorm = Int8PhiLayerNorm(config.hidden_size, eps=config.layer_norm_eps)
self.resid_dropout = nn.Dropout(config.resid_pdrop)

forward = PhiDecoderLayer.forward

@staticmethod
def from_float(module: PhiDecoderLayer,
config: PhiConfig,
quant_config: dict[str, str],
attn_input_scale: float,
q_output_scale: float,
k_output_scale: float,
v_output_scale: float,
dense_input_scale: float,
fc1_input_scale: float,
fc2_input_scale: float
):
int8_module = Int8PhiDecoderLayer(
config,
quant_config,
module.self_attn.layer_idx
)
int8_module.self_attn = Int8PhiAttention.from_float(
module.self_attn,
config,
quant_config,
attn_input_scale,
q_output_scale,
k_output_scale,
v_output_scale,
dense_input_scale
)
int8_module.mlp = Int8PhiMLP.from_float(
module.mlp,
config,
quant_config,
fc1_input_scale,
fc2_input_scale
)
if quant_config["qkv"] == "per-tensor":
int8_module.input_layernorm = Int8PhiLayerNorm.from_float(
module.input_layernorm,
attn_input_scale
)
else:
int8_module.input_layernorm = module.input_layernorm

return int8_module


class Int8PhiModel(PhiPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`PhiDecoderLayer`]

Args:
config: PhiConfig
"""

def __init__(self, config: PhiConfig, quant_config: dict[str, str]):
super().__init__(config)
self.config = config
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size

self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.embed_dropout = nn.Dropout(config.embd_pdrop)
self.layers = nn.ModuleList(
[Int8PhiDecoderLayer(config, quant_config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self._use_sdpa = config._attn_implementation == "sdpa"

self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()

get_input_embeddings = PhiModel.get_input_embeddings
set_input_embeddings = PhiModel.set_input_embeddings
forward = PhiModel.forward

@staticmethod
def from_float(module, decoder_layer_scales, quant_config):
int8_module = Int8PhiModel(module.config, quant_config)

int8_module.embed_tokens = module.embed_tokens
int8_module.final_layernorm = module.final_layernorm

for i, layer in enumerate(module.layers):
int8_module.layers[i] = Int8PhiDecoderLayer.from_float(
layer, module.config, quant_config, **decoder_layer_scales[i])
return int8_module


class Int8PhiForCausalLM(PhiPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]

def __init__(self, config, quant_config):
super().__init__(config)
self.model = Int8PhiModel(config, quant_config)
self.vocab_size = config.vocab_size
# no need to quant
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=True)

# Initialize weights and apply final processing
self.post_init()

get_input_embeddings = PhiForCausalLM.get_input_embeddings
set_input_embeddings = PhiForCausalLM.set_input_embeddings
get_output_embeddings = PhiForCausalLM.get_output_embeddings
set_output_embeddings = PhiForCausalLM.set_output_embeddings
set_decoder = PhiForCausalLM.set_decoder
get_decoder = PhiForCausalLM.get_decoder
forward = PhiForCausalLM.forward
prepare_inputs_for_generation = PhiForCausalLM.prepare_inputs_for_generation
_reorder_cache = PhiForCausalLM._reorder_cache

@staticmethod
def from_float(module, decoder_layer_scales, quant_config):
int8_module = Int8PhiForCausalLM(module.config, quant_config)
print("start trans into int8, this might take a while")
int8_module.model = Int8PhiModel.from_float(
module.model, decoder_layer_scales, quant_config)
int8_module.lm_head = module.lm_head
return int8_module
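Taken together, these `from_float` hooks give Phi-2 the same two-step flow as the other backends: calibrate an fp16 model to get per-layer activation scales, then convert layer by layer. A minimal sketch, with placeholder all-ones scales standing in for real calibration output (paths and values are illustrative):

```python
# Illustrative flow, not part of this diff: convert an fp16 Phi-2
# checkpoint to int8. Real scales come from a SmoothQuant calibration
# pass over sample activations; all-ones scales are placeholders.
import torch
from transformers import PhiForCausalLM

fp16_model = PhiForCausalLM.from_pretrained("microsoft/phi-2",
                                            torch_dtype=torch.float16)

quant_config = {"qkv": "per-tensor", "out": "per-tensor",
                "fc1": "per-tensor", "fc2": "per-tensor"}

# One scale dict per decoder layer, keyed exactly as
# Int8PhiDecoderLayer.from_float expects.
decoder_layer_scales = [dict(
    attn_input_scale=1.0, q_output_scale=1.0, k_output_scale=1.0,
    v_output_scale=1.0, dense_input_scale=1.0,
    fc1_input_scale=1.0, fc2_input_scale=1.0,
) for _ in range(fp16_model.config.num_hidden_layers)]

int8_model = Int8PhiForCausalLM.from_float(
    fp16_model, decoder_layer_scales, quant_config)
int8_model.save_pretrained("quantized_model/phi2-smoothquant")  # output dir illustrative
```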