Skip to content

Commit 9981c1d

Browse files
committed
update qwen3
1 parent e324766 commit 9981c1d

7 files changed

Lines changed: 2292 additions & 396 deletions

File tree

bit_decode/models/cache_utils.py

Lines changed: 740 additions & 376 deletions
Large diffs are not rendered by default.

evaluation/example.py

Lines changed: 31 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,16 @@
44
import torch
55
import random
66
import argparse
7+
8+
from bit_decode import DynamicCache, StaticCache, Cache
9+
import transformers.cache_utils
10+
transformers.cache_utils.DynamicCache = DynamicCache
11+
transformers.cache_utils.StaticCache = StaticCache
12+
transformers.cache_utils.Cache = Cache
13+
714
from llama import LlamaForCausalLM
8-
from transformers import LlamaConfig, AutoTokenizer
15+
from qwen3 import Qwen3ForCausalLM
16+
from transformers import LlamaConfig, Qwen3Config, AutoTokenizer
917
from datasets import load_dataset
1018

1119
def main():
@@ -23,21 +31,34 @@ def main():
2331
random.seed(0)
2432
torch.manual_seed(0)
2533

26-
config = LlamaConfig.from_pretrained(args.model_path)
34+
if "Llama" in args.model_path:
35+
config = LlamaConfig.from_pretrained(args.model_path)
36+
elif "Qwen" in args.model_path:
37+
config = Qwen3Config.from_pretrained(args.model_path)
2738

39+
config._attn_implementation = "flash_attention_2"
2840
config.attn_backend = args.attn_backend
2941
config.num_bits = args.num_bits
3042
config.quant_mode = args.quant_mode
3143
config.group_size = args.group_size
3244
config.residual_block_size = 128 if args.num_bits == 4 else 256
3345

34-
model = LlamaForCausalLM.from_pretrained(
35-
pretrained_model_name_or_path=args.model_path,
36-
config=config,
37-
low_cpu_mem_usage=True,
38-
torch_dtype=torch.float16,
39-
device_map="auto"
40-
)
46+
if "Llama" in args.model_path:
47+
model = LlamaForCausalLM.from_pretrained(
48+
pretrained_model_name_or_path=args.model_path,
49+
config=config,
50+
low_cpu_mem_usage=True,
51+
torch_dtype=torch.float16,
52+
device_map="auto"
53+
)
54+
elif "Qwen" in args.model_path:
55+
model = Qwen3ForCausalLM.from_pretrained(
56+
pretrained_model_name_or_path=args.model_path,
57+
config=config,
58+
low_cpu_mem_usage=True,
59+
torch_dtype=torch.float16,
60+
device_map="auto"
61+
)
4162

4263
enc = AutoTokenizer.from_pretrained(
4364
args.model_path,
@@ -71,7 +92,7 @@ def main():
7192
)
7293
config_str = f"# prompt tokens: {inputs.input_ids.shape[1]}"
7394

74-
print(prompt + "\n" + "=" * 10 + f'\n{config_str}\n' + "=" * 10 + "\nOutput:")
95+
# print(prompt + "\n" + "=" * 10 + f'\n{config_str}\n' + "=" * 10 + "\nOutput:")
7596
# print("\n" + "=" * 10 + f'\n{config_str}\n' + "=" * 10 + "\nOutput:")
7697
print(enc.decode(output[0].tolist()[inputs.input_ids.shape[1]:], skip_special_tokens=True))
7798

evaluation/llama.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -532,8 +532,15 @@ def forward(
532532
value_states = value_states.transpose(1, 2)
533533

534534
if q_len == 1:
535+
print("query_states1: ", query_states.shape)
536+
print("key_states1: ", key_states.shape)
537+
print("value_states1: ", value_states.shape)
538+
535539
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx)
536540

541+
print("query_states2: ", query_states.shape)
542+
print("key_states2: ", key_states.shape)
543+
print("value_states2: ", value_states.shape)
537544
attn_output = flash_attn_with_kvcache(
538545
query_states,
539546
key_states,
@@ -554,6 +561,7 @@ def forward(
554561
is_causal=self.is_causal,
555562
**kwargs,
556563
)
564+
557565
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
558566
attn_output = self.o_proj(attn_output)
559567

evaluation/print.log1

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,6 @@
1+
2+
Answer: Arnel shared 5 x 8 = <<5*8=40>>40 pencils with his friends.
3+
He had 40 + 10 = <<40+10=50>>50 pencils in total.
4+
Since there are ten boxes of pencils, there are 50 / 10 = <<50/10=5>>5 pencils in each box.
5+
#### 5
6+
Question: A bakery has 600 cups of flour. If they use 12 cups of flour for every batch of bread and 8 cups of flour for every

0 commit comments

Comments
 (0)