
Commit 16899bb

Pushed latest changes with chunking enabled for prefill, along with retaining the full KV cache for the decode-only model
Signed-off-by: Onkar Chougule <ochougul@qti.qualcomm.com>
1 parent 3de97bf commit 16899bb

File tree

6 files changed: +112, -87 lines

QEfficient/base/modeling_qeff.py
QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
QEfficient/transformers/models/modeling_auto.py
QEfficient/transformers/models/pytorch_transforms.py
examples/gpt_oss_disagg_mode_with_chunking.py
tests/transformers/models/test_disagg_mode.py
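The new retain_full_kv flag is exercised end to end in the updated example below: the decode-only QPC is compiled first with retain_full_kv=True, so its sliding-window layers keep a full-context KV cache that the prefill output can be copied into, and the prefill-only QPC is then compiled with chunking enabled. A condensed sketch of that flow (compile options such as core and device counts are omitted; the model id and sequence lengths follow the example):

from QEfficient import QEFFAutoModelForCausalLM

PREFILL_SEQ_LEN = 128
CTX_LEN = 128 * 3

qeff_model = QEFFAutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")

# Decode-only model: keep the full KV cache even for sliding-window layers.
decode_qpc_path = qeff_model.compile(
    prefill_seq_len=1,
    ctx_len=CTX_LEN,
    offload_pt_weights=False,  # keep weights in memory for the prefill export below
    retain_full_kv=True,
)

# Prefill-only model: prefill runs in PREFILL_SEQ_LEN-sized chunks.
prefill_qpc_path = qeff_model.compile(
    prefill_seq_len=PREFILL_SEQ_LEN,
    ctx_len=CTX_LEN,
    prefill_only=True,
    enable_chunking=True,
    use_onnx_subfunctions=True,
)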

QEfficient/base/modeling_qeff.py

Lines changed: 13 additions & 2 deletions
@@ -325,8 +325,13 @@ def get_onnx_path(
         specializations: Optional[List[Dict[str, int]]] = None,
         offload_pt_weights: Optional[bool] = True,
         use_onnx_subfunctions: Optional[bool] = False,
+        retain_full_kv: Optional[bool] = False,
     ):
-        kwargs = {"offload_pt_weights": offload_pt_weights, "use_onnx_subfunctions": use_onnx_subfunctions}
+        kwargs = {
+            "offload_pt_weights": offload_pt_weights,
+            "use_onnx_subfunctions": use_onnx_subfunctions,
+            "retain_full_kv": retain_full_kv,
+        }
         if prefill_only:
             if self.prefill_onnx_path is None:
                 kwargs.update(
@@ -360,6 +365,7 @@ def _compile(
         prefill_only: Optional[str] = None,
         offload_pt_weights: Optional[bool] = True,
         enable_chunking: Optional[bool] = False,
+        retain_full_kv: Optional[bool] = None,
         **compiler_options,
     ) -> str:
         """
@@ -389,7 +395,12 @@ def _compile(
             onnx_path
             if onnx_path
             else self.get_onnx_path(
-                prefill_only, enable_chunking, specializations, offload_pt_weights, use_onnx_subfunctions
+                prefill_only,
+                enable_chunking,
+                specializations,
+                offload_pt_weights,
+                use_onnx_subfunctions,
+                retain_full_kv,
             )
         )
         compile_dir = Path(compile_dir or onnx_path.parent)

QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 2 additions & 2 deletions
@@ -1298,11 +1298,11 @@ def forward(

     def get_pkv_dynamic_axes(
         self,
-        chunked_prefill: Optional[bool] = False,
+        retain_full_kv: Optional[bool] = False,
     ):
         pkv_dynamic_axes = []
         for layer_type in self.config.layer_types:
-            if layer_type == "sliding_attention" and not chunked_prefill:
+            if layer_type == "sliding_attention" and not retain_full_kv:
                 pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"})
             else:
                 pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"})
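Outside the model class, the effect of this rename is easiest to see with a small standalone mirror of the helper (a sketch with a made-up layer_types list, not library code): sliding-attention layers only get the smaller sliding_window axis when the full cache is not being retained.

from typing import Dict, List


def pkv_dynamic_axes_sketch(layer_types: List[str], retain_full_kv: bool = False) -> List[Dict[int, str]]:
    # Mirrors the updated get_pkv_dynamic_axes logic: with retain_full_kv set,
    # every layer exports a ctx_len-sized KV axis so a decode-only model can
    # reuse the full cache produced by prefill.
    axes = []
    for layer_type in layer_types:
        if layer_type == "sliding_attention" and not retain_full_kv:
            axes.append({0: "batch_size", 2: "sliding_window"})
        else:
            axes.append({0: "batch_size", 2: "ctx_len"})
    return axes


print(pkv_dynamic_axes_sketch(["sliding_attention", "full_attention"], retain_full_kv=True))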

QEfficient/transformers/models/modeling_auto.py

Lines changed: 29 additions & 7 deletions
@@ -52,6 +52,7 @@
     PoolingTransform,
     PrefillOnlyChunkedTransform,
     PrefillOnlyTransform,
+    RevertPrefillKeepAttentionTransform,
     RevertPrefillOnlyTransform,
     SamplerTransform,
     SpDTransform,
@@ -2303,15 +2304,23 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
         SplitTensorsTransform,
     ]

-    def prefill(self, enable: Optional[bool] = True, enable_chunking: Optional[bool] = False):
+    def prefill(
+        self,
+        enable: Optional[bool] = True,
+        enable_chunking: Optional[bool] = False,
+        retain_full_kv: Optional[bool] = False,
+    ):
         if enable:
             if enable_chunking:
                 self.model, tf = PrefillOnlyChunkedTransform.apply(self.model)
             else:
                 self.model, tf = PrefillOnlyTransform.apply(self.model)
             self.prefill_enabled = True
         else:
-            self.model, tf = RevertPrefillOnlyTransform.apply(self.model)
+            if retain_full_kv:
+                self.model, tf = RevertPrefillKeepAttentionTransform.apply(self.model)
+            else:
+                self.model, tf = RevertPrefillOnlyTransform.apply(self.model)
             self.prefill_enabled = False

     def __init__(
@@ -2478,7 +2487,6 @@ def from_pretrained(
             qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path

         # This is support models that should be classified to in a different auto class but transformers load them via this class
-
         if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP:
             return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__](
                 model,
@@ -2511,7 +2519,6 @@ def get_model_config(self) -> dict:
     def get_seq_len_and_handle_specialized_prefill_model(
         self, prefill_seq_len: Optional[int] = None, enable_chunking=False
     ) -> int:
-        self.prefill(enable=True, enable_chunking=enable_chunking)
         self.hash_params["prefill_only"] = True
         if enable_chunking:
             self.hash_params["chunking"] = True
@@ -2586,6 +2593,8 @@ def export(
         )
         if prefill_only:
             assert not self.continuous_batching, "prefill_only=True is not supported with continuous_batching=True"
+            self.prefill(enable=True, enable_chunking=kwargs.get("enable_chunking", False))
+            self.hash_params.pop("retain_full_kv", None)
             seq_len = (
                 self.get_seq_len_and_handle_specialized_prefill_model(
                     prefill_seq_len=prefill_seq_len, enable_chunking=kwargs.get("enable_chunking", False)
@@ -2597,9 +2606,15 @@ def export(
                 seq_len + self.model.config.sliding_window if kwargs.get("enable_chunking", False) else seq_len
             )
         else:
-            self.prefill(False)
+            self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False))
             self.hash_params.pop("prefill_only", None)
-            self.hash_params.pop("num_blocks", None)
+            self.hash_params.pop("NUM_Q_BLOCKS", None)
+            self.hash_params.pop("NUM_FFN_BLOCKS", None)
+            self.hash_params.pop("ENABLE_OPT_SWA", None)
+            self.hash_params.pop("chunking", None)
+            if kwargs.get("retain_full_kv", False):
+                kv_cache_shape[2] = seq_len + self.model.config.sliding_window
+                self.hash_params["retain_full_kv"] = True

         example_inputs = {
             "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
@@ -2649,7 +2664,10 @@ def export(
         else:
             # HACK: create common function for this including above if condition code
             pkv_dynamic_axes = (
-                self.model.get_pkv_dynamic_axes(chunked_prefill=(prefill_only and kwargs.get("enable_chunking", False)))
+                self.model.get_pkv_dynamic_axes(
+                    retain_full_kv=kwargs.get("retain_full_kv", False)
+                    or (prefill_only and kwargs.get("enable_chunking", False))
+                )
                 if hasattr(self.model, "get_pkv_dynamic_axes")
                 else pkv_dynamic_axes
             )
@@ -2905,6 +2923,7 @@ def compile(
         use_onnx_subfunctions: bool = False,
         offload_pt_weights: Optional[bool] = True,
        enable_chunking: Optional[bool] = False,
+        retain_full_kv: Optional[bool] = None,
         **compiler_options,
     ) -> str:
         """
@@ -3040,6 +3059,8 @@ def compile(
         if self.comp_ctx_lengths_prefill is not None:
             # Adding elements from self.comp_ctx_lengths_prefill to prefill_specialization
             for i in range(0, len(self.comp_ctx_lengths_prefill)):
+                if prefill_only or enable_chunking:
+                    raise NotImplementedError("prefill_only or enable_chunking is not supported with CCL")
                 specializations.append(
                     self.build_prefill_specialization(
                         prefill_seq_len=prefill_seq_len,
@@ -3118,6 +3139,7 @@ def compile(
             prefill_only=prefill_only,
             offload_pt_weights=offload_pt_weights,
             enable_chunking=enable_chunking,
+            retain_full_kv=retain_full_kv,
             **compiler_options,
         )

QEfficient/transformers/models/pytorch_transforms.py

Lines changed: 10 additions & 0 deletions
@@ -654,6 +654,16 @@ class PrefillOnlyChunkedTransform(ModuleMappingTransform):
     }


+class RevertPrefillKeepAttentionTransform(ModuleMappingTransform):
+    _module_mapping = {
+        QEffGptOssModel: QEffPrefillOnlyGptOssModel,
+        QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
+        QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention,
+        QEffPrefillOnlyGptOssMLP: QEffGptOssMLP,
+        QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP,
+    }
+
+
 class RevertPrefillOnlyTransform(ModuleMappingTransform):
     _module_mapping = {
         **{v: k for k, v in PrefillOnlyTransform._module_mapping.items()},
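For context, a ModuleMappingTransform swaps module implementations by class; the snippet below is a minimal self-contained sketch of that idea (not the library's actual base class) showing why RevertPrefillKeepAttentionTransform can revert the MLP modules to QEffGptOssMLP while keeping the chunked attention classes, and with them the full-length KV layout, in place.

from typing import Dict, Tuple, Type

import torch.nn as nn


def apply_module_mapping(model: nn.Module, mapping: Dict[Type[nn.Module], Type[nn.Module]]) -> Tuple[nn.Module, bool]:
    # Any submodule whose class is a key in the mapping gets its class rebound
    # to the mapped target: weights stay untouched, only forward() changes.
    transformed = False
    for module in model.modules():
        if type(module) in mapping:
            module.__class__ = mapping[type(module)]
            transformed = True
    return model, transformed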

examples/gpt_oss_disagg_mode_with_chunking.py

Lines changed: 56 additions & 75 deletions
@@ -9,12 +9,12 @@

 import numpy as np
 import torch
-from transformers import AutoTokenizer
+from transformers import AutoConfig, AutoTokenizer

 from QEfficient import QEFFAutoModelForCausalLM
 from QEfficient.generation.cloud_infer import QAICInferenceSession

-model_id = "openai/gpt-oss-120b"  # weights are not required to convert to fp32
+model_id = "openai/gpt-oss-20b"  # weights are not required to convert to fp32

 prompt = """
 Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.
@@ -23,27 +23,14 @@

 The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location.
 """
-all_outputs = []
 # Run prefill
+config = AutoConfig.from_pretrained(model_id)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 PREFILL_SEQ_LEN = 128
-CTX_LEN = 2 * 128
-inputs = tokenizer(prompt, return_tensors="np", padding=True)
-position_ids = inputs["attention_mask"].sum(1, keepdims=True)
-padded_len = inputs["input_ids"].shape[1]
-num_chunks = -(padded_len // -PREFILL_SEQ_LEN)  # ceil divide without float
-padded_len = num_chunks * PREFILL_SEQ_LEN  # Convert to a multiple of prompt_len
+CTX_LEN = 128 * 3

-# Initialize variables specific to request
-# Calculate the max generation length.
-max_gen_len = CTX_LEN - position_ids.max()
-generation_len = max_gen_len
-
-
-# qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2)
 qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id)

-
 decode_qpc_path = qeff_model.compile(
     prefill_seq_len=1,
     ctx_len=CTX_LEN,
@@ -55,23 +42,12 @@
     aic_enable_depth_first=True,
     num_speculative_tokens=None,
     offload_pt_weights=False,  # Need the weights in memory for prefill-model export/compilation in the next step
+    retain_full_kv=True,
 )

-config = qeff_model.model.config
-inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
-inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
-inputs.pop("token_type_ids", None)
-inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
-past_key_values = []
-for i in range(config.num_hidden_layers):
-    cache_len = config.sliding_window if i % 2 == 0 else PREFILL_SEQ_LEN
-    pad_shape = (1, 8, cache_len, 64)
-    past_key = torch.zeros((pad_shape), dtype=torch.float32)
-    past_value = torch.zeros((pad_shape), dtype=torch.float32)
-    pkv = (past_key, past_value)
-    past_key_values.append(pkv)
-inputs["past_key_values"] = past_key_values

+# Following command errors out by default, the user is supposed to run the printed command and provide the generated qpc path as prefill_qpc_path commenting out lines 55-68
+# prefill_qpc_path = "provide path here"
 prefill_qpc_path = qeff_model.compile(
     prefill_seq_len=PREFILL_SEQ_LEN,
     ctx_len=CTX_LEN,
@@ -85,72 +61,77 @@
     prefill_only=True,
     enable_chunking=True,
     use_onnx_subfunctions=True,
-    offload_pt_weights=False,
-)
-print("loading qpc")
-st = time.time()
-prefill_session = QAICInferenceSession(prefill_qpc_path, device_ids=[i for i in range(32, 48)])
-print(f"time for loading session = {time.time() - st}")
-print("done")
-prefill_session.skip_buffers(
-    [x for x in prefill_session.input_names + prefill_session.output_names if x.startswith("past_")]
 )
-logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32)
-prefill_session.set_buffers({"logits": logits_out_placeholder})
-inputs.pop("past_key_values")
+
+
+inputs = tokenizer(prompt, return_tensors="np", padding=True)
+position_ids = inputs["attention_mask"].sum(1, keepdims=True)
+generation_len = CTX_LEN - position_ids.max()
+padded_len = inputs["input_ids"].shape[1]
+num_chunks = -(padded_len // -PREFILL_SEQ_LEN)  # ceil divide without float
+padded_len = num_chunks * PREFILL_SEQ_LEN  # Convert to a multiple of prompt_len
+inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len)
+inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1)
+inputs.pop("token_type_ids", None)
+inputs = {k: torch.from_numpy(v) for k, v in inputs.items()}
+inputs.pop("past_key_values", None)
 inputs = {k: v.detach().numpy() for k, v in inputs.items()}
-st = time.time()

+
+decode_session = QAICInferenceSession(decode_qpc_path)
+prefill_session = QAICInferenceSession(prefill_qpc_path)
+
+all_outputs = []
 for i in range(num_chunks):
     chunk_inputs = inputs.copy()
     chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
     chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN]
     ins = time.time()
     qpc_out = prefill_session.run(chunk_inputs)
     print(f"time for this run={time.time() - ins}")
-print(f"time for prefill_run={time.time() - st} sec\n")
-
-decode_session = QAICInferenceSession(decode_qpc_path)
-decode_session.set_buffers({"logits": logits_out_placeholder})
+    for i in range(config.num_hidden_layers):
+        inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
+        inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]

+all_outputs.append(np.argmax(qpc_out["logits"]))
 decode_inputs = {
     "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1),
     "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1,
 }
-print("pos_id for decodee", decode_inputs["position_ids"])
-
-all_outputs.append(decode_inputs["input_ids"][0][0])
 for i in range(config.num_hidden_layers):
-    if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window:
-        last_valid_pos_idx = decode_inputs["position_ids"][0][0]
-        first_valid_pos_idx = last_valid_pos_idx - config.sliding_window
-        k = qpc_out[f"past_key.{i}_RetainedState"][:, :, first_valid_pos_idx:last_valid_pos_idx, :]
-        v = qpc_out[f"past_value.{i}_RetainedState"][:, :, first_valid_pos_idx:last_valid_pos_idx, :]
-        mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window
-        decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2)
-        decode_inputs[f"past_value.{i}"] = np.concatenate((v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2)
-    else:
-        decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
-        decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]
+    decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"]
+    decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"]

 st = time.time()
 decode_out = decode_session.run(decode_inputs)
 print(f"time for first run of decode with KV as input = {time.time() - st} sec\n")
-decode_session.skip_buffers(
-    [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")]
-)
+all_outputs.append(np.argmax(decode_out["logits"]))
 pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1
+loop_decode_inputs = {
+    "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
+    "position_ids": pos_id,
+}
+
+for i in range(config.num_hidden_layers):
+    loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"]
+    loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"]
+
 st = time.time()
 for i in range(generation_len - 2):
-    loop_decode_inputs = {
-        "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
-        "position_ids": pos_id,
-    }
-    all_outputs.append(loop_decode_inputs["input_ids"][0][0])
     decode_out = decode_session.run(loop_decode_inputs)
+    all_outputs.append(np.argmax(decode_out["logits"]))
     pos_id += 1
-
-
-print(f"time for decode generation = {(time.time() - st) / (generation_len - 2)}")
-print(all_outputs)
-print(tokenizer.decode(all_outputs))
+    for i in range(config.num_hidden_layers):
+        loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"]
+        loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"]
+
+    loop_decode_inputs.update(
+        {
+            "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1),
+            "position_ids": pos_id,
+        }
+    )
+ft = time.time()
+
+print(f"decode tok/sec={(generation_len - 2) / (ft - st)}")
+print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}")

tests/transformers/models/test_disagg_mode.py

Lines changed: 2 additions & 1 deletion
@@ -16,7 +16,7 @@
 from QEfficient.generation.cloud_infer import QAICInferenceSession
 from QEfficient.transformers.quantizers import replace_transformers_quantizers, undo_transformers_quantizers

-model_id = "openai/gpt-oss-20b"  # weights are not required to convert to fp32
+model_id = "openai/gpt-oss-120b"  # weights are not required to convert to fp32

 prompt2 = """
 Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures.
@@ -104,6 +104,7 @@ def test_disagg_mode_prefill(model_id, prompt):
     assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 5e-2


+@pytest.mark.skip(reason="no way of currently testing this without the assert sdk")
 @pytest.mark.on_qaic
 @pytest.mark.parametrize("model_id", [model_id])
 @pytest.mark.parametrize("prompt", prompts)
