# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import time

import numpy as np
import torch
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM
from QEfficient.generation.cloud_infer import QAICInferenceSession

model_id = "openai/gpt-oss-120b"  # the weights do not need to be converted to fp32

prompt = """
Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.

As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown.

The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location.
"""
all_outputs = []
# Tokenize the prompt and work out the chunking parameters for prefill
tokenizer = AutoTokenizer.from_pretrained(model_id)
PREFILL_SEQ_LEN = 128
CTX_LEN = 2 * 128
inputs = tokenizer(prompt, return_tensors="np", padding=True)
position_ids = inputs["attention_mask"].sum(1, keepdims=True)
padded_len = inputs["input_ids"].shape[1]
num_chunks = -(padded_len // -PREFILL_SEQ_LEN)  # ceil divide without float
padded_len = num_chunks * PREFILL_SEQ_LEN  # round up to a multiple of PREFILL_SEQ_LEN

# Initialize variables specific to this request.
# The maximum generation length is whatever context remains after the prompt.
max_gen_len = CTX_LEN - position_ids.max()
generation_len = max_gen_len


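# Load the model through QEfficient's causal-LM wrapper. The commented-out
# variant below (presumably for quick functional testing) truncates the model
# to 2 hidden layers so that export and compilation finish much faster.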
# qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)


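# Compile a decode-only QPC first: prefill_seq_len=1 means this program
# processes one token per step. offload_pt_weights=False keeps the PyTorch
# weights in host memory so the prefill model can still be exported below.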
decode_qpc_path = qeff_model.compile(
    prefill_seq_len=1,
    ctx_len=CTX_LEN,
    num_cores=16,
    mxfp6_matmul=True,
    mxint8_kv_cache=True,
    num_devices=1,
    mos=1,
    aic_enable_depth_first=True,
    num_speculative_tokens=None,
    offload_pt_weights=False,  # keep the weights in memory for the prefill-model export/compilation below
)

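# Build example inputs for the prefill export. gpt-oss alternates
# sliding-window attention (even layers) and full attention (odd layers), so
# the dummy KV-cache length differs per layer; the (1, 8, cache_len, 64)
# shape matches this model's KV-head count and head dimension.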
config = qeff_model.model.config
inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
inputs.pop("token_type_ids", None)
inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
past_key_values = []
for i in range(config.num_hidden_layers):
    cache_len = config.sliding_window if i % 2 == 0 else PREFILL_SEQ_LEN
    pad_shape = (1, 8, cache_len, 64)
    past_key = torch.zeros(pad_shape, dtype=torch.float32)
    past_value = torch.zeros(pad_shape, dtype=torch.float32)
    past_key_values.append((past_key, past_value))
inputs["past_key_values"] = past_key_values

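# Compile the prefill QPC: prefill_only drops the decode path,
# enable_chunking lets a long prompt be fed in PREFILL_SEQ_LEN-sized pieces,
# and use_onnx_subfunctions exports repeated blocks as ONNX functions to
# keep the exported graph compact.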
prefill_qpc_path = qeff_model.compile(
    prefill_seq_len=PREFILL_SEQ_LEN,
    ctx_len=CTX_LEN,
    num_cores=16,
    mxfp6_matmul=True,
    mxint8_kv_cache=True,
    num_devices=1,
    mos=1,
    aic_enable_depth_first=True,
    num_speculative_tokens=None,
    prefill_only=True,
    enable_chunking=True,
    use_onnx_subfunctions=True,
    offload_pt_weights=False,
)
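# Load the prefill QPC onto the device(s). The device_ids below assume a
# particular device topology (IDs 32-47); adjust for your setup. Skipping the
# past_* buffers lets the KV cache live in device memory as retained state
# instead of being copied to and from the host on every run.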
print("loading qpc")
st = time.time()
prefill_session = QAICInferenceSession(prefill_qpc_path, device_ids=list(range(32, 48)))
print(f"time for loading session = {time.time() - st}")
print("done")
prefill_session.skip_buffers(
    [x for x in prefill_session.input_names + prefill_session.output_names if x.startswith("past_")]
)
logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32)
prefill_session.set_buffers({"logits": logits_out_placeholder})
inputs.pop("past_key_values")
inputs = {k: v.detach().numpy() for k, v in inputs.items()}
st = time.time()

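# Chunked prefill: feed the padded prompt PREFILL_SEQ_LEN tokens at a time.
# Each chunk updates the retained KV state on device; only the last chunk's
# logits matter for picking the first generated token.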
for i in range(num_chunks):
    chunk_inputs = inputs.copy()
    chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
    chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
    ins = time.time()
    qpc_out = prefill_session.run(chunk_inputs)
    print(f"time for this run={time.time() - ins}")
print(f"time for prefill_run={time.time() - st} sec\n")

decode_session = QAICInferenceSession(decode_qpc_path)
decode_session.set_buffers({"logits": logits_out_placeholder})

decode_inputs = {
    "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1),
    "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1,
}
print("pos_id for decode", decode_inputs["position_ids"])

all_outputs.append(decode_inputs["input_ids"][0][0])
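# Seed the decode session's KV cache from the prefill outputs. For the
# sliding-window (even) layers, once the current position exceeds
# sliding_window the retained cache behaves as a ring buffer, so the last
# sliding_window entries are sliced out and rotated back into contiguous
# order before being handed to the decode program.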
for i in range(config.num_hidden_layers):
    if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window:
        last_valid_pos_idx = decode_inputs["position_ids"][0][0]
        first_valid_pos_idx = last_valid_pos_idx - config.sliding_window
        k = qpc_out[f"past_key.{i}_RetainedState"][:, :, first_valid_pos_idx:last_valid_pos_idx, :]
        v = qpc_out[f"past_value.{i}_RetainedState"][:, :, first_valid_pos_idx:last_valid_pos_idx, :]
        mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window
        decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2)
        decode_inputs[f"past_value.{i}"] = np.concatenate((v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2)
    else:
        decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
        decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]

st = time.time()
decode_out = decode_session.run(decode_inputs)
print(f"time for first run of decode with KV as input = {time.time() - st} sec\n")
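# After this first run has seeded the device-side KV cache, skip the past_*
# buffers so subsequent decode steps reuse the retained state on device
# rather than shuttling the cache across the host interface each step.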
decode_session.skip_buffers(
    [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")]
)
pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1
st = time.time()
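# Greedy decode loop: each step feeds back the argmax token from the previous
# step's logits. The first token came from prefill and the loop's first
# iteration consumes the seeded decode run's logits, hence generation_len - 2
# remaining steps.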
for i in range(generation_len - 2):
    loop_decode_inputs = {
        "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
        "position_ids": pos_id,
    }
    all_outputs.append(loop_decode_inputs["input_ids"][0][0])
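    # Optional early exit (an addition, not part of the original flow): stop
    # once the tokenizer's EOS token is emitted. Note that breaking early
    # skews the per-step timing print below, which divides by the full
    # generation_len - 2.
    if tokenizer.eos_token_id is not None and loop_decode_inputs["input_ids"][0][0] == tokenizer.eos_token_id:
        break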
    decode_out = decode_session.run(loop_decode_inputs)
    pos_id += 1


print(f"avg time per decode step = {(time.time() - st) / (generation_len - 2)} sec")
print(all_outputs)
print(tokenizer.decode(all_outputs))