
Commit ced0683

added disagg mode example for chunking mode

Signed-off-by: Onkar Chougule <ochougul@qti.qualcomm.com>
1 parent 631d988

File tree

4 files changed: +161, −4 lines

QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 0 additions & 1 deletion

@@ -99,7 +99,6 @@ def forward(self, hidden: torch.Tensor):
         return expert_out.view(B, S, H), router_logits
 
 
-
 class QEffPrefillOnlyGptOssMLP(GptOssMLP):
     def forward(self, hidden: torch.Tensor):
         if os.environ.get("NUM_FFN_BLOCKS", None) is not None:

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 4 additions & 3 deletions

@@ -655,9 +655,10 @@ class PrefillOnlyChunkedTransform(ModuleMappingTransform):
 
 
 class RevertPrefillOnlyTransform(ModuleMappingTransform):
-    _module_mapping = {v: k for k, v in PrefillOnlyTransform._module_mapping.items()}.update(
-        {v: k for k, v in PrefillOnlyChunkedTransform._module_mapping.items()}
-    )
+    _module_mapping = {
+        **{v: k for k, v in PrefillOnlyTransform._module_mapping.items()},
+        **{v: k for k, v in PrefillOnlyChunkedTransform._module_mapping.items()},
+    }
 
 
 class SpDTransform:
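Why this hunk matters: dict.update() mutates its receiver and returns None, so the old code bound _module_mapping to None instead of the merged reverse mapping, leaving RevertPrefillOnlyTransform with nothing to revert. A minimal standalone sketch of the failure and the fix (toy dicts standing in for the real transform mappings):

    forward = {"A": "B"}
    chunked = {"C": "D"}

    # Old pattern: the merged dict is built, then thrown away,
    # because .update() returns None.
    broken = {v: k for k, v in forward.items()}.update({v: k for k, v in chunked.items()})
    assert broken is None

    # New pattern: dict unpacking returns the merged reverse mapping.
    fixed = {**{v: k for k, v in forward.items()}, **{v: k for k, v in chunked.items()}}
    assert fixed == {"B": "A", "D": "C"}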

examples/gpt_oss_disagg_mode.py

Lines changed: 1 addition & 0 deletions

@@ -80,6 +80,7 @@
     aic_enable_depth_first=True,
     num_speculative_tokens=None,
     prefill_only=True,
+    use_onnx_subfunctions=True,
 )
 
 prefill_session = QAICInferenceSession(prefill_qpc_path)
Lines changed: 156 additions & 0 deletions

@@ -0,0 +1,156 @@
# -----------------------------------------------------------------------------
#
# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------

import time

import numpy as np
import torch
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM
from QEfficient.generation.cloud_infer import QAICInferenceSession

model_id = "openai/gpt-oss-120b"  # the weights do not need to be converted to fp32

prompt = """
Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.

As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown.

The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location.
"""
all_outputs = []

# Run prefill
tokenizer = AutoTokenizer.from_pretrained(model_id)
PREFILL_SEQ_LEN = 128
CTX_LEN = 2 * 128
inputs = tokenizer(prompt, return_tensors="np", padding=True)
position_ids = inputs["attention_mask"].sum(1, keepdims=True)
padded_len = inputs["input_ids"].shape[1]
num_chunks = -(padded_len // -PREFILL_SEQ_LEN)  # ceil divide without float
padded_len = num_chunks * PREFILL_SEQ_LEN  # round the padded length up to a multiple of PREFILL_SEQ_LEN
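# Worked example of the ceil-divide idiom above: a 300-token prompt with
# PREFILL_SEQ_LEN = 128 gives -(300 // -128) = 3 chunks, so padded_len = 384.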

# Initialize variables specific to the request.
# Calculate the max generation length.
max_gen_len = CTX_LEN - position_ids.max()
generation_len = max_gen_len


# Uncomment the truncated variant below for a quick functional test:
# qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)
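
# Disaggregated mode builds two QPCs from the same model object: a decode-only
# program (prefill_seq_len=1) and a chunked, prefill-only program. The KV cache
# produced by the prefill program is handed to the decode program explicitly
# further below.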

decode_qpc_path = qeff_model.compile(
    prefill_seq_len=1,
    ctx_len=CTX_LEN,
    num_cores=16,
    mxfp6_matmul=True,
    mxint8_kv_cache=True,
    num_devices=1,
    mos=1,
    aic_enable_depth_first=True,
    num_speculative_tokens=None,
    offload_pt_weights=False,  # keep the weights in memory for the prefill-model export/compilation below
)

config = qeff_model.model.config
inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
inputs.pop("token_type_ids", None)
inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
past_key_values = []
for i in range(config.num_hidden_layers):
    cache_len = config.sliding_window if i % 2 == 0 else PREFILL_SEQ_LEN
    pad_shape = (1, 8, cache_len, 64)  # (batch, kv_heads, cache_len, head_dim)
    past_key = torch.zeros(pad_shape, dtype=torch.float32)
    past_value = torch.zeros(pad_shape, dtype=torch.float32)
    past_key_values.append((past_key, past_value))
inputs["past_key_values"] = past_key_values

prefill_qpc_path = qeff_model.compile(
    prefill_seq_len=PREFILL_SEQ_LEN,
    ctx_len=CTX_LEN,
    num_cores=16,
    mxfp6_matmul=True,
    mxint8_kv_cache=True,
    num_devices=1,
    mos=1,
    aic_enable_depth_first=True,
    num_speculative_tokens=None,
    prefill_only=True,
    enable_chunking=True,
    use_onnx_subfunctions=True,
    offload_pt_weights=False,
)
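# In the compile call above, prefill_only=True specializes the QPC for prompt
# processing, and enable_chunking=True lets it consume the prompt in
# PREFILL_SEQ_LEN-sized slices while carrying KV state across runs.
# use_onnx_subfunctions=True is assumed here to export repeated blocks as ONNX
# functions so the exported graph stays compact.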
print("loading qpc")
st = time.time()
prefill_session = QAICInferenceSession(prefill_qpc_path, device_ids=list(range(32, 48)))
print(f"time for loading session = {time.time() - st}")
print("done")
prefill_session.skip_buffers(
    [x for x in prefill_session.input_names + prefill_session.output_names if x.startswith("past_")]
)
logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32)  # (batch, 1, vocab_size)
prefill_session.set_buffers({"logits": logits_out_placeholder})
inputs.pop("past_key_values")
inputs = {k: v.detach().numpy() for k, v in inputs.items()}
st = time.time()

for i in range(num_chunks):
    chunk_inputs = inputs.copy()
    chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
    chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
    ins = time.time()
    qpc_out = prefill_session.run(chunk_inputs)
    print(f"time for this run={time.time() - ins}")
print(f"time for prefill_run={time.time() - st} sec\n")

decode_session = QAICInferenceSession(decode_qpc_path)
decode_session.set_buffers({"logits": logits_out_placeholder})

decode_inputs = {
    "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1),
    "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1,
}
print("pos_id for decode", decode_inputs["position_ids"])

all_outputs.append(decode_inputs["input_ids"][0][0])
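# Hand the prefill KV cache over to decode. The sliding-window layers
# (even-indexed) maintain their cache as a circular buffer, so once the
# position id passes sliding_window the example slices out the most recent
# sliding_window entries and rotates them back into chronological order;
# full-attention layers pass their retained state through unchanged.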
for i in range(config.num_hidden_layers):
    if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window:
        last_valid_pos_idx = decode_inputs["position_ids"][0][0]
        first_valid_pos_idx = last_valid_pos_idx - config.sliding_window
        k = qpc_out[f"past_key.{i}_RetainedState"][:, :, first_valid_pos_idx:last_valid_pos_idx, :]
        v = qpc_out[f"past_value.{i}_RetainedState"][:, :, first_valid_pos_idx:last_valid_pos_idx, :]
        mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window
        decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2)
        decode_inputs[f"past_value.{i}"] = np.concatenate((v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2)
    else:
        decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
        decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]

st = time.time()
decode_out = decode_session.run(decode_inputs)
print(f"time for first run of decode with KV as input = {time.time() - st} sec\n")
decode_session.skip_buffers(
    [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")]
)
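# After the first decode run the KV cache lives in the decode QPC's retained
# state, so the past_* buffers are skipped from here on and only input_ids and
# position_ids cross the host-device boundary each step.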
pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1
st = time.time()
for i in range(generation_len - 2):
    loop_decode_inputs = {
        "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
        "position_ids": pos_id,
    }
    all_outputs.append(loop_decode_inputs["input_ids"][0][0])
    decode_out = decode_session.run(loop_decode_inputs)
    pos_id += 1


print(f"time for decode generation = {(time.time() - st) / (generation_len - 2)}")
print(all_outputs)
print(tokenizer.decode(all_outputs))
