Skip to content
This repository was archived by the owner on Oct 25, 2024. It is now read-only.

Commit a4aba8d

Browse files
authored
[LLM example] add calib_shuffle args for text-generation example (#1087)
Signed-off-by: Wang, Chang1 <chang1.wang@intel.com>
1 parent e8170aa commit a4aba8d

File tree

5 files changed

+24
-6
lines changed

5 files changed

+24
-6
lines changed

examples/huggingface/pytorch/text-generation/quantization/run_generation.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
AutoModel,
1212
)
1313
from transformers.utils import check_min_version
14+
from intel_extension_for_transformers.transformers.utils import str2bool
1415
from optimum.intel.generation.modeling import TSModelForCausalLM
1516
from intel_extension_for_transformers.transformers import (
1617
MixedPrecisionConfig,
@@ -67,6 +68,12 @@
6768
parser.add_argument(
6869
"--calib_padding", action="store_true", help="Calibration dataset do padding."
6970
)
71+
parser.add_argument(
72+
"--calib_shuffle",
73+
default=True,
74+
type=str2bool,
75+
help="Calibration dataset do shuffle.",
76+
)
7077
parser.add_argument(
7178
"--calib_pad_val", default=1, type=int, help="Calibration dataset padding value."
7279
)
@@ -126,16 +133,14 @@
126133
parser.add_argument("--load_in_4bit", type=bool, default=False)
127134
parser.add_argument("--load_in_8bit", type=bool, default=False)
128135
parser.add_argument("--_commit_hash", default="main", type=str)
129-
parser.add_argument("--trust_remote_code", default=False)
136+
parser.add_argument("--trust_remote_code", type=bool, default=False)
130137
parser.add_argument("--use_llm_runtime", action="store_true")
131138
# =======================================
132139
args = parser.parse_args()
133-
134140
# transformers version >= 4.32.0 contained the mpt modeling definition.
135141
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/mpt/modeling_mpt.py
136142
# 4.31.0 for ipex.optimize_transformers
137143
check_min_version("4.31.0")
138-
139144
# get model config
140145
if args.peft_model_id:
141146
from peft import PeftConfig
@@ -228,6 +233,7 @@
228233
op_type_dict=op_type_dict, # default is {}
229234
excluded_precisions=excluded_precisions, # default is []
230235
num_beams=generate_kwargs["num_beams"],
236+
calib_shuffle=args.calib_shuffle,
231237
calib_iters=args.calib_iters,
232238
calib_padding=args.calib_padding,
233239
calib_len=args.calib_len,
@@ -257,7 +263,6 @@
257263
trust_remote_code=args.trust_remote_code,
258264
_commit_hash=args._commit_hash,
259265
use_llm_runtime=args.use_llm_runtime,
260-
261266
)
262267
elif args.load_in_4bit or args.load_in_8bit:
263268
# CPU device usage is provided by intel-extension-for-transformers.

intel_extension_for_transformers/transformers/modeling/modeling_auto.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
380380
from torch.utils.data import DataLoader
381381

382382
calib_dataset = quantization_config.calib_dataset
383+
calib_shuffle = quantization_config.calib_shuffle
383384
calib_iters = quantization_config.calib_iters
384385
calib_padding = quantization_config.calib_padding
385386
calib_len = quantization_config.calib_len
@@ -392,7 +393,8 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
392393
if calib_dataset in ["mbpp", "openai_humaneval"]
393394
else "train",
394395
)
395-
calib_dataset = calib_dataset.shuffle(seed=42)
396+
if calib_shuffle:
397+
calib_dataset = calib_dataset.shuffle(seed=42)
396398

397399
def tokenize_function(examples):
398400
if "prompt" in examples:

intel_extension_for_transformers/transformers/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,4 @@
2424
SparsityConfig,
2525
WeightOnlyQuantConfig,
2626
)
27-
from .utility import LazyImport, logger
27+
from .utility import LazyImport, logger, str2bool

intel_extension_for_transformers/transformers/utils/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,7 @@ class SmoothQuantConfig:
390390
tokenizer: Any = None
391391
calib_func: Any = None
392392
calib_dataset: str = "NeelNanda/pile-10k"
393+
calib_shuffle: bool = True
393394
calib_iters: int = 100
394395
calib_padding: bool = False
395396
calib_len: int = 512

intel_extension_for_transformers/transformers/utils/utility.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
"""Utils for pytorch framework."""
1919

20+
import argparse
2021
import os
2122
from typing import Optional, Tuple
2223
from neural_compressor.utils import logger
@@ -36,6 +37,15 @@
3637

3738
torch = LazyImport("torch")
3839

40+
def str2bool(v):
41+
if isinstance(v, bool):
42+
return v
43+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
44+
return True
45+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
46+
return False
47+
else:
48+
raise argparse.ArgumentTypeError('Boolean value expected.')
3949

4050
def distributed_init(
4151
backend="gloo",

0 commit comments

Comments (0)