|
9 | 9 |
|
10 | 10 | import numpy as np |
11 | 11 | import torch |
12 | | -from transformers import AutoTokenizer |
| 12 | +from transformers import AutoConfig, AutoTokenizer |
13 | 13 |
|
14 | 14 | from QEfficient import QEFFAutoModelForCausalLM |
15 | 15 | from QEfficient.generation.cloud_infer import QAICInferenceSession |
16 | 16 |
|
17 | | -model_id = "openai/gpt-oss-120b" # weights are not required to convert to fp32 |
| 17 | +model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 |
18 | 18 |
|
19 | 19 | prompt = """ |
20 | 20 | Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. |
|
23 | 23 |
|
24 | 24 | The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. |
25 | 25 | """ |
26 | | -all_outputs = [] |
27 | 26 | # Run prefill |
| 27 | +config = AutoConfig.from_pretrained(model_id) |
28 | 28 | tokenizer = AutoTokenizer.from_pretrained(model_id) |
29 | 29 | PREFILL_SEQ_LEN = 128 |
30 | | -CTX_LEN = 2 * 128 |
31 | | -inputs = tokenizer(prompt, return_tensors="np", padding=True) |
32 | | -position_ids = inputs["attention_mask"].sum(1, keepdims=True) |
33 | | -padded_len = inputs["input_ids"].shape[1] |
34 | | -num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float |
35 | | -padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len |
| 30 | +CTX_LEN = 128 * 3 |
36 | 31 |
|
37 | | -# Initialize variables specific to request |
38 | | -# Calculate the max generation length. |
39 | | -max_gen_len = CTX_LEN - position_ids.max() |
40 | | -generation_len = max_gen_len |
41 | | - |
42 | | - |
43 | | -# qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) |
44 | 32 | qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) |
45 | 33 |
|
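| | +# Compile the decode-only QPC first (prefill_seq_len=1); the prefill QPC is compiled separately below and its KV cache is passed to this session explicitly at decode time |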
46 | | - |
47 | 34 | decode_qpc_path = qeff_model.compile( |
48 | 35 | prefill_seq_len=1, |
49 | 36 | ctx_len=CTX_LEN, |
|
55 | 42 | aic_enable_depth_first=True, |
56 | 43 | num_speculative_tokens=None, |
57 | 44 | offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step |
| 45 | + retain_full_kv=True, |
58 | 46 | ) |
59 | 47 |
|
60 | | -config = qeff_model.model.config |
61 | | -inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) |
62 | | -inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) |
63 | | -inputs.pop("token_type_ids", None) |
64 | | -inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} |
65 | | -past_key_values = [] |
66 | | -for i in range(config.num_hidden_layers): |
67 | | - cache_len = config.sliding_window if i % 2 == 0 else PREFILL_SEQ_LEN |
68 | | - pad_shape = (1, 8, cache_len, 64) |
69 | | - past_key = torch.zeros((pad_shape), dtype=torch.float32) |
70 | | - past_value = torch.zeros((pad_shape), dtype=torch.float32) |
71 | | - pkv = (past_key, past_value) |
72 | | - past_key_values.append(pkv) |
73 | | -inputs["past_key_values"] = past_key_values |
74 | 48 |
|
| 49 | +# The compile call below errors out by default; run the printed command, then set prefill_qpc_path to the generated QPC path (uncomment the line below) and comment out the compile call. |
| 50 | +# prefill_qpc_path = "provide path here" |
75 | 51 | prefill_qpc_path = qeff_model.compile( |
76 | 52 | prefill_seq_len=PREFILL_SEQ_LEN, |
77 | 53 | ctx_len=CTX_LEN, |
|
85 | 61 | prefill_only=True, |
86 | 62 | enable_chunking=True, |
87 | 63 | use_onnx_subfunctions=True, |
88 | | - offload_pt_weights=False, |
89 | | -) |
90 | | -print("loading qpc") |
91 | | -st = time.time() |
92 | | -prefill_session = QAICInferenceSession(prefill_qpc_path, device_ids=[i for i in range(32, 48)]) |
93 | | -print(f"time for loading session = {time.time() - st}") |
94 | | -print("done") |
95 | | -prefill_session.skip_buffers( |
96 | | - [x for x in prefill_session.input_names + prefill_session.output_names if x.startswith("past_")] |
97 | 64 | ) |
98 | | -logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) |
99 | | -prefill_session.set_buffers({"logits": logits_out_placeholder}) |
100 | | -inputs.pop("past_key_values") |
| 65 | + |
| 66 | + |
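| | +# Tokenize the prompt, round its length up to a multiple of PREFILL_SEQ_LEN, and mark padded positions with position_id -1 |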
| 67 | +inputs = tokenizer(prompt, return_tensors="np", padding=True) |
| 68 | +position_ids = inputs["attention_mask"].sum(1, keepdims=True) |
| 69 | +generation_len = CTX_LEN - position_ids.max() |
| 70 | +padded_len = inputs["input_ids"].shape[1] |
| 71 | +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float |
| 72 | +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len |
| 73 | +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) |
| 74 | +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) |
| 75 | +inputs.pop("token_type_ids", None) |
| 76 | +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} |
| 77 | +inputs.pop("past_key_values", None) |
101 | 78 | inputs = {k: v.detach().numpy() for k, v in inputs.items()} |
102 | | -st = time.time() |
103 | 79 |
|
| 80 | + |
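| | +# Create inference sessions for the decode and prefill QPCs |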
| 81 | +decode_session = QAICInferenceSession(decode_qpc_path) |
| 82 | +prefill_session = QAICInferenceSession(prefill_qpc_path) |
| 83 | + |
| 84 | +all_outputs = [] |
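| | +# Run prefill chunk by chunk, feeding each chunk's retained KV state back in for the next chunk |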
104 | 85 | for i in range(num_chunks): |
105 | 86 | chunk_inputs = inputs.copy() |
106 | 87 | chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] |
107 | 88 | chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] |
108 | 89 | ins = time.time() |
109 | 90 | qpc_out = prefill_session.run(chunk_inputs) |
110 | 91 | print(f"time for this run={time.time() - ins}") |
111 | | -print(f"time for prefill_run={time.time() - st} sec\n") |
112 | | - |
113 | | -decode_session = QAICInferenceSession(decode_qpc_path) |
114 | | -decode_session.set_buffers({"logits": logits_out_placeholder}) |
| 92 | + for layer in range(config.num_hidden_layers): |
| 93 | + inputs[f"past_key.{layer}"] = qpc_out[f"past_key.{layer}_RetainedState"] |
| 94 | + inputs[f"past_value.{layer}"] = qpc_out[f"past_value.{layer}_RetainedState"] |
115 | 95 |
|
| 96 | +all_outputs.append(np.argmax(qpc_out["logits"])) |
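| | +# First decode step: pass the argmax token from prefill together with the full prefill KV cache |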
116 | 97 | decode_inputs = { |
117 | 98 | "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), |
118 | 99 | "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, |
119 | 100 | } |
120 | | -print("pos_id for decodee", decode_inputs["position_ids"]) |
121 | | - |
122 | | -all_outputs.append(decode_inputs["input_ids"][0][0]) |
123 | 101 | for i in range(config.num_hidden_layers): |
124 | | - if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window: |
125 | | - last_valid_pos_idx = decode_inputs["position_ids"][0][0] |
126 | | - first_valid_pos_idx = last_valid_pos_idx - config.sliding_window |
127 | | - k = qpc_out[f"past_key.{i}_RetainedState"][:, :, first_valid_pos_idx:last_valid_pos_idx, :] |
128 | | - v = qpc_out[f"past_value.{i}_RetainedState"][:, :, first_valid_pos_idx:last_valid_pos_idx, :] |
129 | | - mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window |
130 | | - decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2) |
131 | | - decode_inputs[f"past_value.{i}"] = np.concatenate((v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2) |
132 | | - else: |
133 | | - decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] |
134 | | - decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] |
| 102 | + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] |
| 103 | + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] |
135 | 104 |
|
136 | 105 | st = time.time() |
137 | 106 | decode_out = decode_session.run(decode_inputs) |
138 | 107 | print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") |
139 | | -decode_session.skip_buffers( |
140 | | - [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")] |
141 | | -) |
| 108 | +all_outputs.append(np.argmax(decode_out["logits"])) |
142 | 109 | pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 |
| 110 | +loop_decode_inputs = { |
| 111 | + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), |
| 112 | + "position_ids": pos_id, |
| 113 | +} |
| 114 | + |
| 115 | +for i in range(config.num_hidden_layers): |
| 116 | + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] |
| 117 | + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] |
| 118 | + |
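| | +# Greedy decode loop: feed back the retained KV state and the argmax token from each step |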
143 | 119 | st = time.time() |
144 | 120 | for i in range(generation_len - 2): |
145 | | - loop_decode_inputs = { |
146 | | - "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), |
147 | | - "position_ids": pos_id, |
148 | | - } |
149 | | - all_outputs.append(loop_decode_inputs["input_ids"][0][0]) |
150 | 121 | decode_out = decode_session.run(loop_decode_inputs) |
| 122 | + all_outputs.append(np.argmax(decode_out["logits"])) |
151 | 123 | pos_id += 1 |
152 | | - |
153 | | - |
154 | | -print(f"time for decode generation = {(time.time() - st) / (generation_len - 2)}") |
155 | | -print(all_outputs) |
156 | | -print(tokenizer.decode(all_outputs)) |
| 124 | + for layer in range(config.num_hidden_layers): |
| 125 | + loop_decode_inputs[f"past_key.{layer}"] = decode_out[f"past_key.{layer}_RetainedState"] |
| 126 | + loop_decode_inputs[f"past_value.{layer}"] = decode_out[f"past_value.{layer}_RetainedState"] |
| 127 | + |
| 128 | + loop_decode_inputs.update( |
| 129 | + { |
| 130 | + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), |
| 131 | + "position_ids": pos_id, |
| 132 | + } |
| 133 | + ) |
| 134 | +ft = time.time() |
| 135 | + |
| 136 | +print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") |
| 137 | +print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}") |