44import torch
55import random
66import argparse
7+
8+ from bit_decode import DynamicCache , StaticCache , Cache
9+ import transformers .cache_utils
10+ transformers .cache_utils .DynamicCache = DynamicCache
11+ transformers .cache_utils .StaticCache = StaticCache
12+ transformers .cache_utils .Cache = Cache
13+
714from llama import LlamaForCausalLM
8- from transformers import LlamaConfig , AutoTokenizer
15+ from qwen3 import Qwen3ForCausalLM
16+ from transformers import LlamaConfig , Qwen3Config , AutoTokenizer
917from datasets import load_dataset
1018
1119def main ():
@@ -23,21 +31,34 @@ def main():
2331 random .seed (0 )
2432 torch .manual_seed (0 )
2533
26- config = LlamaConfig .from_pretrained (args .model_path )
34+ if "Llama" in args .model_path :
35+ config = LlamaConfig .from_pretrained (args .model_path )
36+ elif "Qwen" in args .model_path :
37+ config = Qwen3Config .from_pretrained (args .model_path )
2738
39+ config ._attn_implementation = "flash_attention_2"
2840 config .attn_backend = args .attn_backend
2941 config .num_bits = args .num_bits
3042 config .quant_mode = args .quant_mode
3143 config .group_size = args .group_size
3244 config .residual_block_size = 128 if args .num_bits == 4 else 256
3345
34- model = LlamaForCausalLM .from_pretrained (
35- pretrained_model_name_or_path = args .model_path ,
36- config = config ,
37- low_cpu_mem_usage = True ,
38- torch_dtype = torch .float16 ,
39- device_map = "auto"
40- )
46+ if "Llama" in args .model_path :
47+ model = LlamaForCausalLM .from_pretrained (
48+ pretrained_model_name_or_path = args .model_path ,
49+ config = config ,
50+ low_cpu_mem_usage = True ,
51+ torch_dtype = torch .float16 ,
52+ device_map = "auto"
53+ )
54+ elif "Qwen" in args .model_path :
55+ model = Qwen3ForCausalLM .from_pretrained (
56+ pretrained_model_name_or_path = args .model_path ,
57+ config = config ,
58+ low_cpu_mem_usage = True ,
59+ torch_dtype = torch .float16 ,
60+ device_map = "auto"
61+ )
4162
4263 enc = AutoTokenizer .from_pretrained (
4364 args .model_path ,
@@ -71,7 +92,7 @@ def main():
7192 )
7293 config_str = f"# prompt tokens: { inputs .input_ids .shape [1 ]} "
7394
74- print (prompt + "\n " + "=" * 10 + f'\n { config_str } \n ' + "=" * 10 + "\n Output:" )
95+ # print(prompt + "\n" + "=" * 10 + f'\n{config_str}\n' + "=" * 10 + "\nOutput:")
7596 # print("\n" + "=" * 10 + f'\n{config_str}\n' + "=" * 10 + "\nOutput:")
7697 print (enc .decode (output [0 ].tolist ()[inputs .input_ids .shape [1 ]:], skip_special_tokens = True ))
7798
0 commit comments