Changes from all commits (37 commits)
58b8e6e
[QEff]: Add gpt_oss
vbaddi Aug 6, 2025
582fc17
nit: update modeling and make transform uniform
vbaddi Aug 7, 2025
6352ac2
apirunner change
ochougul Aug 7, 2025
296dc9a
added test along with simplified Hybridcache
ochougul Aug 7, 2025
6c9e79c
added test assert
ochougul Aug 7, 2025
df5dd62
nit: update test gpt file
vbaddi Aug 8, 2025
f806ac5
nit: update modeling with new decode moe forward
vbaddi Aug 11, 2025
0a6aa9c
nit: seperate gate, up projections for MoE
vbaddi Aug 20, 2025
7731691
nit: remove test file and add sample test in config
Oct 15, 2025
15ebe39
Enable CB for GptOssModel
mamtsing Nov 3, 2025
52f64b4
Fix tests
mamtsing Nov 4, 2025
79cbae9
Address review comments
mamtsing Nov 4, 2025
3e2a261
prefill only changes for gpt-oss
ochougul Nov 4, 2025
0e3a673
fixed mapping
ochougul Nov 5, 2025
3ce4320
added test
ochougul Nov 6, 2025
e929616
added test
ochougul Nov 6, 2025
40ab876
made example not ugly
ochougul Nov 6, 2025
b9defbe
fixed tests
ochougul Nov 6, 2025
446f4b6
fixed tests
ochougul Nov 6, 2025
099fd61
added new test and fixed failing tests
ochougul Nov 7, 2025
4d4639e
fixed tests
ochougul Nov 10, 2025
f32df62
fixed kv cache shape
ochougul Nov 10, 2025
0b29ba4
fixed self.onnx_path issue in modeling_qeff
ochougul Nov 11, 2025
053acaa
added ffn blocking and num blocks env variables
ochougul Nov 13, 2025
8447c18
include num_ffn_blocks in hash
ochougul Nov 17, 2025
3982c9d
fixed dynamic range in case of subfunc issue and nonmatching ctx, pre…
ochougul Nov 18, 2025
4c38de3
added swa optimization for reducing MACCs using less KV
ochougul Nov 18, 2025
5e5d708
added opt swa to hash
ochougul Nov 24, 2025
88ae0be
lint and format
ochougul Nov 24, 2025
5d014c2
enabled chunking
ochougul Nov 26, 2025
f1b1785
added ChunkedPrefillMLP block; fixed passing prefill_only flag and en…
ochougul Dec 1, 2025
723f4ad
added disagg mode example for chunking mode
ochougul Dec 2, 2025
1bc8ee9
fixed the kwargs passing to build_decode_specialization
ochougul Dec 2, 2025
abccb26
pushed latest changes with chunking enabled for prefill along with re…
ochougul Dec 8, 2025
1721b9d
added support for prefix caching for gpt-oss
ochougul Dec 8, 2025
1b60a5f
removed error
ochougul Dec 9, 2025
e8d1128
added errors for prefill-only mode
ochougul Dec 9, 2025
23 changes: 15 additions & 8 deletions QEfficient/__init__.py
@@ -6,7 +6,17 @@
# -----------------------------------------------------------------------------

import os
import warnings

# ----------------------------------------------------------------------------- #
# For faster downloads via hf_transfer
# This code is put above import statements as this needs to be executed before
# hf_transfer is imported (will happen on line 15 via leading imports)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# DO NOT ADD ANY CODE ABOVE THIS LINE
# Please contact maintainers if you must edit this file above this line.
# ----------------------------------------------------------------------------- #
# Placeholder for all non-transformer models registered in QEfficient
import warnings # noqa: I001

import QEfficient.utils.model_registery # noqa: F401
from QEfficient.base import (
@@ -25,6 +35,10 @@
from QEfficient.utils import custom_format_warning
from QEfficient.utils.logging_utils import logger

# custom warning for the better logging experience
warnings.formatwarning = custom_format_warning


# Users can use QEfficient.export for exporting models to ONNX
export = qualcomm_efficient_converter
__all__ = [
@@ -40,14 +54,7 @@
"QEFFAutoModelForSpeechSeq2Seq",
"QEFFCommonLoader",
]
# For faster downloads via hf_transfer
# This code is put above import statements as this needs to be executed before
# hf_transfer is imported (will happen on line 15 via leading imports)
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Placeholder for all non-transformer models registered in QEfficient

# custom warning for the better logging experience
warnings.formatwarning = custom_format_warning

# Conditionally import QAIC-related modules if the SDK is installed
__version__ = "0.0.1.dev0"
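Note for reviewers: moving the HF_HUB_ENABLE_HF_TRANSFER assignment above the imports matters because huggingface_hub reads the variable when it is first imported. A minimal, hypothetical sketch of the required ordering (illustrative only, not QEfficient code):

# Sketch only: setting the env var after importing huggingface_hub has no effect,
# because the flag is read into huggingface_hub's constants at import time.
import os

os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"  # must come before the import below

from huggingface_hub import snapshot_download  # noqa: E402

# With the hf_transfer package installed, downloads now use the accelerated backend.
snapshot_download("gpt2", allow_patterns=["*.json"])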
75 changes: 68 additions & 7 deletions QEfficient/base/modeling_qeff.py
@@ -66,6 +66,8 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
self.model = model
self.hash_params = create_model_params(self, **kwargs)
self.prefill_enabled = False
self.prefill_onnx_path: Optional[str] = None
self.onnx_path: Optional[str] = None
self.qpc_path: Optional[str] = None
self.qpc_session: Optional[QAICInferenceSession] = None
@@ -189,6 +191,7 @@ def _export(
export_dir: Optional[str] = None,
offload_pt_weights: bool = True,
use_onnx_subfunctions: bool = False,
prefill_only: Optional[bool] = False,
) -> str:
"""
Export the PyTorch model to ONNX and apply ONNX transforms
@@ -217,7 +220,10 @@

# Return early if ONNX already exists
if onnx_path.is_file():
self.onnx_path = onnx_path
if prefill_only:
self.prefill_onnx_path = onnx_path
else:
self.onnx_path = onnx_path
return onnx_path

# check if the model is in meta state or weights are offloaded
@@ -315,9 +321,42 @@
self._onnx_transforms.remove(CustomOpTransform)
self._onnx_transforms.remove(RenameFunctionOutputsTransform)

self.onnx_path = onnx_path
if prefill_only:
self.prefill_onnx_path = onnx_path
else:
self.onnx_path = onnx_path
return onnx_path

def get_onnx_path(
self,
prefill_only: Optional[bool] = False,
enable_chunking: Optional[bool] = False,
specializations: Optional[List[Dict[str, int]]] = None,
offload_pt_weights: Optional[bool] = True,
use_onnx_subfunctions: Optional[bool] = False,
retain_full_kv: Optional[bool] = False,
):
kwargs = {
"offload_pt_weights": offload_pt_weights,
"use_onnx_subfunctions": use_onnx_subfunctions,
"retain_full_kv": retain_full_kv,
}
if prefill_only:
if self.prefill_onnx_path is None:
kwargs.update(
{
"prefill_only": prefill_only,
"prefill_seq_len": specializations[0].get("seq_len"),
"enable_chunking": enable_chunking,
}
)
self.export(**kwargs)
return self.prefill_onnx_path
else:
if self.onnx_path is None:
self.export(**kwargs)
return self.onnx_path

@dump_qconfig
def _compile(
self,
@@ -332,6 +371,10 @@
enable_qnn: Optional[bool] = False,
qnn_config: Optional[str] = None,
use_onnx_subfunctions: bool = False,
prefill_only: Optional[str] = None,
offload_pt_weights: Optional[bool] = True,
enable_chunking: Optional[bool] = False,
retain_full_kv: Optional[bool] = None,
**compiler_options,
) -> str:
"""
@@ -357,11 +400,18 @@

For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored.
"""

if onnx_path is None and self.onnx_path is None:
self.export(use_onnx_subfunctions=use_onnx_subfunctions)

onnx_path = Path(onnx_path or self.onnx_path)
onnx_path = Path(
onnx_path
if onnx_path
else self.get_onnx_path(
prefill_only,
enable_chunking,
specializations,
offload_pt_weights,
use_onnx_subfunctions,
retain_full_kv,
)
)
compile_dir = Path(compile_dir or onnx_path.parent)
qpc_path = compile_dir / "qpc"
if not onnx_path.is_file():
@@ -423,6 +473,7 @@ def _compile(
"mdp_ts_num_devices": mdp_ts_num_devices,
"mdp_ts_json": mdp_ts_json,
"num_speculative_tokens": num_speculative_tokens,
"prefill_only": prefill_only,
}
compile_hash = hash_dict_params(compile_hash_params)

@@ -462,6 +513,16 @@

command.append(f"-aic-binary-dir={qpc_path}")
logger.info(f"Running compiler: {' '.join(command)}")
if use_onnx_subfunctions:

class FeatureNotAvailableError(Exception):
pass

exec_command = f'QAIC_COMPILER_OPTS_UNSUPPORTED="-loader-inline-all=0" {" ".join(command)}'
raise FeatureNotAvailableError(
"ONNX graph is exported with subfunctions, assert version of apps SDK should be used for compiling this model."
+ f"\nRun following command manually with assert compiler:\n{exec_command}"
)
try:
subprocess.run(command, capture_output=True, check=True)
except subprocess.CalledProcessError as e:
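Reviewer note: the modeling_qeff.py changes above cache two ONNX artifacts (a prefill-only graph and the regular graph) and dispatch between them in get_onnx_path. A condensed, standalone sketch of that dispatch, with a hypothetical export_model callable standing in for self.export():

# Standalone sketch of the prefill-aware ONNX-path caching added to QEFFBaseModel.
# Names here are illustrative only; the real logic lives in get_onnx_path/_export.
from typing import Callable, Dict, List, Optional


class OnnxPathCache:
    def __init__(self) -> None:
        self.prefill_onnx_path: Optional[str] = None  # prefill-only graph
        self.onnx_path: Optional[str] = None  # regular prefill+decode graph

    def get_onnx_path(
        self,
        export_model: Callable[..., str],  # hypothetical stand-in for self.export()
        prefill_only: bool = False,
        specializations: Optional[List[Dict[str, int]]] = None,
    ) -> str:
        if prefill_only:
            if self.prefill_onnx_path is None:
                # The prefill graph is specialized to the prefill sequence length.
                seq_len = specializations[0].get("seq_len") if specializations else None
                self.prefill_onnx_path = export_model(prefill_only=True, prefill_seq_len=seq_len)
            return self.prefill_onnx_path
        if self.onnx_path is None:
            self.onnx_path = export_model(prefill_only=False, prefill_seq_len=None)
        return self.onnx_path


# Usage: the two graphs are exported lazily and cached independently.
def fake_export(prefill_only: bool, prefill_seq_len: Optional[int]) -> str:
    return f"prefill_{prefill_seq_len}.onnx" if prefill_only else "full.onnx"


cache = OnnxPathCache()
print(cache.get_onnx_path(fake_export, prefill_only=True, specializations=[{"seq_len": 128}]))
print(cache.get_onnx_path(fake_export))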
4 changes: 2 additions & 2 deletions QEfficient/base/onnx_transforms.py
@@ -91,10 +91,10 @@ class CustomOpTransform(BaseOnnxTransform):
"CtxScatterFunc3D": (CtxScatterFunc3D, CtxScatter3D),
"CtxGatherFunc": (CtxGatherFunc, CtxGather),
"CtxGatherFunc3D": (CtxGatherFunc3D, CtxGather3D),
"CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
"CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D),
"CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
"CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D),
"CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB),
"CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB),
}

@classmethod
1 change: 1 addition & 0 deletions QEfficient/customop/ctx_scatter_gather.py
@@ -136,6 +136,7 @@ class CtxGatherFunc(torch.autograd.Function):
def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices)
return data[batch_indices, head_indices, ctx_indices]

@staticmethod
1 change: 1 addition & 0 deletions QEfficient/customop/ctx_scatter_gather_cb.py
@@ -126,6 +126,7 @@ class CtxGatherFuncCB(torch.autograd.Function):
def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
batch_indices = batch_index.view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
ctx_indices = torch.where(ctx_indices >= data.shape[2], 0, ctx_indices)
return data[batch_indices, head_indices, ctx_indices]

@staticmethod
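Both one-line gather changes above clamp indices before indexing. A self-contained sketch of why (toy shapes, not the exported custom op): positions carrying the int32 sentinel would otherwise index past the cache during eager PyTorch execution, while reading slot 0 instead is harmless because those positions are masked to zero afterwards.

# Sketch of the index clamping added to CtxGatherFunc / CtxGatherFuncCB.
import torch

INVALID = torch.iinfo(torch.int32).max

# Toy cache: [batch=2, heads=1, ctx_len=4, head_dim=1]
data = torch.arange(2 * 1 * 4 * 1, dtype=torch.float32).view(2, 1, 4, 1)
ctx_indices = torch.tensor([[[0, 1, INVALID]]])  # last position flagged invalid

batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)

# Without clamping, indexing with INVALID would raise an out-of-bounds error.
safe_indices = torch.where(ctx_indices == INVALID, 0, ctx_indices)
gathered = data[batch_indices, head_indices, safe_indices]

# Invalid positions are zeroed downstream (see the value-cache masking in update()),
# so the placeholder read from slot 0 never reaches the attention output.
invalid_mask = ctx_indices == INVALID
gathered = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0), gathered)
print(gathered.shape)  # torch.Size([2, 1, 3, 1])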
4 changes: 2 additions & 2 deletions QEfficient/peft/auto.py
@@ -253,7 +253,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs):
obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs)
return obj

def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Export the model with the active adapter to ONNX format.

@@ -294,7 +294,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
export_kwargs={"do_constant_folding": False}, # To avoid merging adapter weights with base weights
onnx_transform_kwargs={"adapter_name": self.model.active_adapter},
export_dir=export_dir,
use_onnx_subfunctions=use_onnx_subfunctions,
**kwargs,
)

def compile(
Expand Down
4 changes: 2 additions & 2 deletions QEfficient/peft/lora/auto.py
@@ -327,7 +327,7 @@ def _init_adapter_model(self):
# load_weight to model
self._load_adapter_weights_to_model()

def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str:
def export(self, export_dir: Optional[str] = None, **kwargs) -> str:
"""
Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``.

@@ -387,7 +387,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool =
output_names,
dynamic_axes,
export_dir=export_dir,
use_onnx_subfunctions=use_onnx_subfunctions,
**kwargs,
)

def generate(
121 changes: 121 additions & 0 deletions QEfficient/transformers/cache_utils.py
@@ -46,6 +46,7 @@ def _get_invalid_idx_value(cls):
"""
if torch.onnx.is_in_onnx_export():
if cls.SUBFUNC_ENABLED:
# TODO: should not return 0 remove this if condition, it can hurt perf
return 0
else:
return torch.iinfo(torch.int32).max
@@ -681,6 +682,37 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
return legacy_cache

def write_only(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
k_out, v_out = key_states, value_states
else:
position_ids = cache_kwargs.get("position_ids")
is_sliding_layer = cache_kwargs.get("is_sliding")
_, _, ctx_len, _ = self.key_cache[layer_idx].shape
if is_sliding_layer:
kv_position_ids = torch.arange(ctx_len, dtype=torch.int64).reshape(1, -1)
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], kv_position_ids, value_states
)
else:
kv_position_ids = position_ids

self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(
self.value_cache[layer_idx], kv_position_ids, value_states
)
k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
return k_out, v_out

def update(
self,
key_states: torch.Tensor,
@@ -747,3 +779,92 @@

v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)
return k_out, v_out

def full_cache_update_chunked(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index")
invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()

# Scatter
if batch_index is not None:
if torch.onnx.is_in_onnx_export():
scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids)
self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
)
self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
)
else:
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states)

k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]

# Gather
ctx_len = cache_kwargs.get("CCL", k_out.shape[2])
ctx_indices = torch.arange(ctx_len)[None, None, ...]
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
if batch_index is not None:
k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len)
v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len)
else:
k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len)
v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len)
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)

return k_out, v_out

def sliding_window_update_chunked(
self,
key_states: torch.Tensor,
value_states: torch.Tensor,
layer_idx: int,
cache_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
position_ids = cache_kwargs.get("position_ids")
batch_index = cache_kwargs.get("batch_index")
invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value()

if batch_index is not None:
if torch.onnx.is_in_onnx_export():
scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids)
self.key_cache[layer_idx] = CtxScatterFuncCB.apply(
self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states
)
self.value_cache[layer_idx] = CtxScatterFuncCB.apply(
self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states
)
else:
self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states)
self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states)

k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx]
sliding_window_len = cache_kwargs.get("sliding_window")

# Gather
ctx_len = position_ids.shape[1] + sliding_window_len
ctx_indices = torch.arange(ctx_len)[None, None, ...]
first_pos_idx = position_ids[0][0]
add_idx = torch.where(first_pos_idx >= sliding_window_len, first_pos_idx - sliding_window_len, 0)
ctx_indices += add_idx
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)
invalid_mask = ctx_indices > gather_limit
ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices)
if batch_index is not None:
k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len)
v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len)
else:
k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len)
v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len)
v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out)

return k_out, v_out
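Worked example of the gather window computed by sliding_window_update_chunked above: for a chunk starting at position first_pos, only the sliding_window cached tokens preceding the chunk plus the chunk itself are gathered, which is the MAC reduction the "swa optimization" commit refers to. A toy walkthrough, assuming a 4-token chunk and an 8-token window:

# Toy walkthrough of the gather window for sliding-window layers during chunked prefill.
import torch

chunk_len = 4       # tokens in the current prefill chunk
sliding_window = 8  # layer's sliding-window size
position_ids = torch.arange(16, 16 + chunk_len).view(1, -1)  # chunk covers positions 16..19

ctx_len = position_ids.shape[1] + sliding_window  # only 12 cache slots need to be read
ctx_indices = torch.arange(ctx_len)[None, None, ...]

first_pos_idx = position_ids[0][0]  # 16
# Start the window sliding_window tokens before the chunk (clamped at 0 for early chunks).
add_idx = torch.where(first_pos_idx >= sliding_window, first_pos_idx - sliding_window, 0)
ctx_indices = ctx_indices + add_idx  # gather indices 8..19 instead of 0..19

gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1)  # 19
invalid_mask = ctx_indices > gather_limit  # all False here; early chunks mask the unused tail
print(ctx_indices.squeeze().tolist())  # [8, 9, ..., 19]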
3 changes: 3 additions & 0 deletions QEfficient/transformers/modeling_utils.py
@@ -188,6 +188,9 @@
# This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc.
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"}

# This is for supporting different modelling classes specially written for prefill-only model
SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"}

# Define a transformers layers to QEff layers dictionary
# While onboarding new models make sure to add the new layer maps to this dictionary.
TransformersToQEffModulesDict: Dict[Type[nn.Module], Type[nn.Module]] = {