
Commit 21b5496

added support for prefix caching for gpt-oss
Signed-off-by: Onkar Chougule <ochougul@qti.qualcomm.com>
1 parent 16899bb · commit 21b5496

File tree: 6 files changed (+80, -40 lines)


QEfficient/base/modeling_qeff.py

Lines changed: 2 additions & 2 deletions
@@ -511,8 +511,8 @@ class FeatureNotAvailableError(Exception):
 
             exec_command = f'QAIC_COMPILER_OPTS_UNSUPPORTED="-loader-inline-all=0" {" ".join(command)}'
             raise FeatureNotAvailableError(
-                f"ONNX graph is exported with subfunctions, assert version of apps SDK should be used for compiling this model. \
-                    Run following command manually with assert compiler:\n{exec_command}"
+                "ONNX graph is exported with subfunctions, assert version of apps SDK should be used for compiling this model."
+                + f"\nRun following command manually with assert compiler:\n{exec_command}"
             )
         try:
             subprocess.run(command, capture_output=True, check=True)
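
Note on the change above: the original f-string used a backslash line continuation inside the string literal, which bakes the source indentation into the error message; the explicit concatenation keeps the message clean. A minimal standalone illustration (plain Python, not QEfficient code; the command string is a placeholder):

exec_command = "qaic-compile --some-flags model.onnx"  # placeholder, not a real command line

# Old style: the continuation keeps the leading spaces of the next source line in the message.
old_msg = f"ONNX graph is exported with subfunctions. \
        Run following command manually:\n{exec_command}"

# New style: explicit concatenation with a newline, no stray indentation.
new_msg = "ONNX graph is exported with subfunctions." + f"\nRun following command manually:\n{exec_command}"

print(repr(old_msg))  # '...subfunctions.         Run following command...'
print(repr(new_msg))  # '...subfunctions.\nRun following command...'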

QEfficient/base/onnx_transforms.py

Lines changed: 14 additions & 0 deletions
@@ -12,6 +12,16 @@
 from onnx import ModelProto, external_data_helper, numpy_helper
 
 from QEfficient.customop.ctx_scatter_gather import CtxGather, CtxGatherFunc, CtxScatter, CtxScatterFunc
+from QEfficient.customop.ctx_scatter_gather_cb import (
+    CtxGatherCB,
+    CtxGatherCB3D,
+    CtxGatherFuncCB,
+    CtxGatherFuncCB3D,
+    CtxScatterCB,
+    CtxScatterCB3D,
+    CtxScatterFuncCB,
+    CtxScatterFuncCB3D,
+)
 from QEfficient.customop.rms_norm import CustomRMSNorm, CustomRMSNormFunc
 
 
@@ -114,6 +124,10 @@ class CustomOpTransform(OnnxTransform):
         "CustomRMSNormFunc": (CustomRMSNormFunc, CustomRMSNorm),
         "CtxScatterFunc": (CtxScatterFunc, CtxScatter),
         "CtxGatherFunc": (CtxGatherFunc, CtxGather),
+        "CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D),
+        "CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D),
+        "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
+        "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
     }
 
     @classmethod
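
For context, each entry in this table pairs the torch.autograd.Function name that appears in the traced graph with the custom-op definition that replaces it in the exported ONNX model; the new CB (continuous-batching) and CB3D variants are what the batch_index-aware cache updates below rely on. A minimal, self-contained sketch of the lookup pattern only (placeholder classes, not the QEfficient implementations):

# Illustrative only: op name seen during export -> (Function used in tracing, replacement op).
class FakeCtxGatherFuncCB: ...
class FakeCtxGatherCB: ...

_custom_ops = {"CtxGatherFuncCB": (FakeCtxGatherFuncCB, FakeCtxGatherCB)}

func_cls, op_cls = _custom_ops["CtxGatherFuncCB"]
print(func_cls.__name__, op_cls.__name__)  # FakeCtxGatherFuncCB FakeCtxGatherCB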

QEfficient/customop/ctx_scatter_gather_cb.py

Lines changed: 1 addition & 0 deletions
@@ -126,6 +126,7 @@ class CtxGatherFuncCB(torch.autograd.Function):
     def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = batch_index.view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
+        ctx_indices = torch.where(ctx_indices >= data.shape[2], 0, ctx_indices)
         return data[batch_indices, head_indices, ctx_indices]
 
     @staticmethod
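
The added torch.where guards the gather against out-of-range context indices: slots marked with the int32-max sentinel (see InvalidIndexProvider in cache_utils.py below) would otherwise index past the cache length, so they are clamped to 0 and the gathered values at those slots are discarded later by the invalid mask. A minimal standalone sketch of the same pattern, with toy shapes:

import torch

# Toy cache: [batch, heads, ctx_len, head_dim]
data = torch.arange(2 * 1 * 4 * 1, dtype=torch.float32).view(2, 1, 4, 1)
batch_index = torch.tensor([1])                        # pick a row of the full batch
ctx_indices = torch.tensor([[[0, 1, 2, 2147483647]]])  # last slot carries the "invalid" sentinel

# Same guard as in CtxGatherFuncCB.forward: clamp out-of-range indices to 0.
ctx_indices = torch.where(ctx_indices >= data.shape[2], 0, ctx_indices)

batch_indices = batch_index.view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
out = data[batch_indices, head_indices, ctx_indices]
print(out.shape)  # torch.Size([1, 1, 4, 1]); the clamped slot just repeats position 0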

QEfficient/transformers/cache_utils.py

Lines changed: 43 additions & 13 deletions
@@ -44,6 +44,7 @@ def _get_invalid_idx_value(cls):
         """
         if torch.onnx.is_in_onnx_export():
             if cls.SUBFUNC_ENABLED:
+                # TODO: should not return 0 remove this if condition, it can hurt perf
                 return 0
             else:
                 return torch.iinfo(torch.int32).max
@@ -722,9 +723,22 @@ def full_cache_update_chunked(
         cache_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         position_ids = cache_kwargs.get("position_ids")
+        batch_index = cache_kwargs.get("batch_index")
+        invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
 
-        self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
-        self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states)
+        # Scatter
+        if batch_index is not None:
+            if torch.onnx.is_in_onnx_export():
+                scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids)
+            self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+                self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+            )
+            self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+                self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+            )
+        else:
+            self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
+            self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states)
 
         k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
 
@@ -733,11 +747,13 @@ def full_cache_update_chunked(
         ctx_indices = torch.arange(ctx_len)[None, None, ...]
         gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
         invalid_mask = ctx_indices > gather_limit
-
-        invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
-        k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len)
-        v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len)
+        if batch_index is not None:
+            k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len)
+            v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len)
+        else:
+            k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len)
+            v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len)
         v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
 
         return k_out, v_out
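
In short: without batch_index the chunked update scatters and gathers along the per-request batch dimension directly, while with continuous batching the cache is allocated for full_batch_size requests and batch_index selects which cache rows the current chunk belongs to. A minimal standalone sketch of that batch-indexed scatter (plain torch, toy shapes; not the CtxScatterFuncCB implementation itself):

import torch

full_batch_size, heads, ctx_len, head_dim = 4, 2, 8, 1
key_cache = torch.zeros(full_batch_size, heads, ctx_len, head_dim)

# One decode-batch row being updated: the request sits in cache slot 2 and
# writes new keys for positions 5..6 of its context.
batch_index = torch.tensor([[2]])               # [batch, 1]
position_ids = torch.tensor([[5, 6]])           # [batch, seq_len]
key_states = torch.ones(1, heads, 2, head_dim)  # [batch, heads, seq_len, head_dim]

# Scatter into the full-batch cache (same indexing idea as CtxScatterFuncCB).
b = batch_index.view(-1, 1, 1)
h = torch.arange(heads).view(1, -1, 1)
p = position_ids.unsqueeze(1)                   # [batch, 1, seq_len]
key_cache[b, h, p] = key_states
print(key_cache[2, 0, :, 0])  # positions 5 and 6 are now 1.0, the rest stay 0.0
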
@@ -750,26 +766,40 @@ def sliding_window_update_chunked(
         cache_kwargs: Optional[Dict[str, Any]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         position_ids = cache_kwargs.get("position_ids")
+        batch_index = cache_kwargs.get("batch_index")
+        invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
 
-        self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
-        self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states)
+        if batch_index is not None:
+            if torch.onnx.is_in_onnx_export():
+                scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids)
+            self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
+                self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
+            )
+            self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
+                self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
+            )
+        else:
+            self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
+            self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states)
 
         k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
         sliding_window_len = cache_kwargs.get("sliding_window")
+
         # Gather
         ctx_len = position_ids.shape[1] + sliding_window_len
         ctx_indices = torch.arange(ctx_len)[None, None, ...]
-        # positive_pos_ids = torch.where(position_ids<0, 0, position_ids)
         first_pos_idx = position_ids[0][0]
         add_idx = torch.where(first_pos_idx >= sliding_window_len, first_pos_idx - sliding_window_len, 0)
         ctx_indices += add_idx
         gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
         invalid_mask = ctx_indices > gather_limit
-
-        invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()
         ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
-        k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len)
-        v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len)
+        if batch_index is not None:
+            k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len)
+            v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len)
+        else:
+            k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len)
+            v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len)
         v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
 
         return k_out, v_out
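
The sliding-window variant gathers ctx_len = seq_len + sliding_window slots and shifts the window forward once the chunk's first position has moved past the window size, so only the most recent cached entries plus the current chunk are read. A standalone numeric sketch of that index arithmetic, with an illustrative chunk and window size:

import torch

sliding_window = 4
position_ids = torch.tensor([[6, 7]])   # current chunk covers positions 6..7
seq_len = position_ids.shape[1]

ctx_len = seq_len + sliding_window      # 6 gathered slots per token row
ctx_indices = torch.arange(ctx_len)[None, None, ...]

first_pos = position_ids[0][0]
# Shift the window forward once the first position has passed the window size.
add_idx = torch.where(first_pos >= sliding_window, first_pos - sliding_window, 0)
ctx_indices = ctx_indices + add_idx
print(ctx_indices)  # tensor([[[2, 3, 4, 5, 6, 7]]]) -> last 4 cached slots plus the new chunk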

QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 8 additions & 12 deletions
@@ -92,8 +92,7 @@ def forward(self, hidden: torch.Tensor):
         down_out = (intermediate @ W_d) + b_d  # [T, H]
 
         # Apply routing weights and accumulate
-        masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out))
-        expert_out += masked_down
+        expert_out += down_out * routing_weight
 
         # original shape [B, S, H]
         return expert_out.view(B, S, H), router_logits
@@ -148,8 +147,7 @@ def forward(self, hidden: torch.Tensor):
         down_out = (intermediate @ W_d) + b_d  # [T, H]
 
         # Apply routing weights and accumulate
-        masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out))
-        expert_out += masked_down
+        expert_out += down_out * routing_weight
 
         # original shape [B, S, H]
         return expert_out.view(B, S, H), router_logits
@@ -221,8 +219,7 @@ def blocked_ffn_forward(self, hidden: torch.Tensor):
         down_out = torch.cat(outs, dim=0)
 
         # Apply routing weights and accumulate
-        masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out))
-        expert_out += masked_down
+        expert_out += down_out * routing_weight
 
         # original shape [B, S, H]
         return expert_out.view(B, S, H), router_logits
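
The removed torch.where only masked contributions from experts whose routing weight is zero; multiplying by a zero routing weight already contributes nothing, so the simpler accumulation is equivalent (up to signed-zero corner cases). A quick standalone check:

import torch

down_out = torch.randn(3, 4)
routing_weight = torch.tensor([[0.7], [0.0], [0.3]])  # middle row: expert not selected

old = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(down_out))
new = down_out * routing_weight

print(torch.allclose(old, new))  # True: a zero weight already zeroes the contribution
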
@@ -1296,16 +1293,15 @@ def forward(
             router_logits=outputs.router_logits,
         )
 
-    def get_pkv_dynamic_axes(
-        self,
-        retain_full_kv: Optional[bool] = False,
-    ):
+    def get_pkv_dynamic_axes(self, retain_full_kv: Optional[bool] = False, continuous_batching: Optional[bool] = False):
         pkv_dynamic_axes = []
         for layer_type in self.config.layer_types:
             if layer_type == "sliding_attention" and not retain_full_kv:
-                pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"})
+                pkv_dynamic_axes.append(
+                    {0: "full_batch_size" if continuous_batching else "batch_size", 2: "sliding_window"}
+                )
             else:
-                pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"})
+                pkv_dynamic_axes.append({0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"})
         return pkv_dynamic_axes
 
     def get_specializations(
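
With continuous_batching=True the batch axis of every past key/value input is exported as full_batch_size instead of batch_size, while the context axis stays sliding_window for sliding-attention layers and ctx_len for full-attention layers. A small standalone sketch of the resulting dynamic-axes mapping for a hypothetical two-layer configuration (layer types here are illustrative):

from typing import Dict, List

def pkv_dynamic_axes(layer_types: List[str], continuous_batching: bool) -> List[Dict[int, str]]:
    # Mirrors the logic above for illustration; not the QEfficient method itself.
    batch_axis = "full_batch_size" if continuous_batching else "batch_size"
    return [
        {0: batch_axis, 2: "sliding_window" if lt == "sliding_attention" else "ctx_len"}
        for lt in layer_types
    ]

print(pkv_dynamic_axes(["sliding_attention", "full_attention"], continuous_batching=True))
# [{0: 'full_batch_size', 2: 'sliding_window'}, {0: 'full_batch_size', 2: 'ctx_len'}]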

QEfficient/transformers/models/modeling_auto.py

Lines changed: 12 additions & 13 deletions
@@ -2592,7 +2592,6 @@ def export(
             self.model.config, fbs if self.continuous_batching else bs, seq_len
         )
         if prefill_only:
-            assert not self.continuous_batching, "prefill_only=True is not supported with continuous_batching=True"
             self.prefill(enable=True, enable_chunking=kwargs.get("enable_chunking", False))
             self.hash_params.pop("retain_full_kv", None)
             seq_len = (
@@ -2666,7 +2665,8 @@ def export(
         pkv_dynamic_axes = (
             self.model.get_pkv_dynamic_axes(
                 retain_full_kv=kwargs.get("retain_full_kv", False)
-                or (prefill_only and kwargs.get("enable_chunking", False))
+                or (prefill_only and kwargs.get("enable_chunking", False)),
+                continuous_batching=self.continuous_batching,
             )
             if hasattr(self.model, "get_pkv_dynamic_axes")
             else pkv_dynamic_axes
@@ -2678,7 +2678,6 @@ def export(
         )
 
         for i in range(self.num_layers):
-            pkv_dynamic_axes[i][0] = "full_batch_size" if self.continuous_batching else "batch_size"
             for kv in ["key", "value"]:
                 example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32))
                 dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i]
@@ -3030,15 +3029,6 @@ def compile(
         if self.is_tlm:
            num_speculative_tokens = self.check_and_get_num_speculative_tokens(num_speculative_tokens, prefill_seq_len)
 
-        if self.continuous_batching and full_batch_size is None:
-            raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")
-
-        if kv_cache_batch_size and not full_batch_size:
-            raise ValueError(
-                "KV caching requires continuous batching. Please set `full_batch_size` and "
-                "enable `continuous_batching=True` in `from_pretrained`."
-            )
-
         if (
             self.model.qaic_config is not None
             and self.model.qaic_config.get("include_sampler", False)
@@ -3048,7 +3038,9 @@ def compile(
             raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.")
 
         if kv_cache_batch_size and prefill_only is not None and prefill_only:
-            logger.warning("kv_cache_batch_size will be ignored as prefill_only is set to True")
+            logger.warning(
+                "kv_cache_batch_size will be ignored as prefill_only is set to True unless this is GPTOSS model"
+            )
 
         # Infer kv_cache_batch_size if not provided
         kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
@@ -3086,6 +3078,13 @@ def compile(
         )
 
         if prefill_only is None or not prefill_only:
+            if self.continuous_batching and full_batch_size is None:
+                raise TypeError("`full_batch_size` is required when `continuous_batching=True`.")
+            if kv_cache_batch_size and not full_batch_size:
+                raise ValueError(
+                    "KV caching requires continuous batching. Please set `full_batch_size` and "
+                    "enable `continuous_batching=True` in `from_pretrained`."
+                )
             if self.comp_ctx_lengths_decode is not None:
                 # Adding elements from self.comp_ctx_lengths_decode to decode_specialization
                 for i in range(0, len(self.comp_ctx_lengths_decode)):
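
Taken together, these changes drop the earlier restriction that a prefill-only export could not be combined with continuous batching, and defer the full_batch_size / kv_cache_batch_size validation to the decode path. A hedged usage sketch of what a prefix-caching style prefill compile for gpt-oss might look like; the model card, sizes, and exact keyword combination are illustrative assumptions, not taken from this commit:

from QEfficient import QEFFAutoModelForCausalLM

# Illustrative model card and sizes; adjust to your setup.
model = QEFFAutoModelForCausalLM.from_pretrained(
    "openai/gpt-oss-20b",      # assumed model card
    continuous_batching=True,  # KV cache laid out per full_batch_size slot
)

# Prefill-only compile: with this commit, continuous batching is allowed here and
# kv_cache_batch_size sizes the shared cache that prefix caching scatters into.
model.compile(
    prefill_seq_len=128,
    ctx_len=4096,
    num_cores=16,
    num_devices=1,
    prefill_only=True,       # flag used in the compile path shown above
    kv_cache_batch_size=4,   # assumed to size the shared KV cache rows
)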
