Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions .github/workflows/quickcheck.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# Lightweight CI gate: runs only the quickcheck test module on every PR.
name: Quickcheck

on:
  pull_request:
  workflow_dispatch:  # allow manual runs from the Actions tab

# Cancel a superseded run for the same PR (or ref, for manual runs) to save CI minutes.
concurrency:
  group: quickcheck-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:
  quickcheck:
    runs-on: ubuntu-latest
    timeout-minutes: 90  # hard cap so a hung model download/compile cannot block the queue
    steps:
      - name: Checkout Repo
        uses: actions/checkout@v4

      - name: Setup Python
        uses: actions/setup-python@v5
        with:
          python-version: "3.10"
          cache: "pip"  # reuse the pip download cache across runs

      - name: Install Dependencies
        run: |
          python -m pip install --upgrade pip
          python -m pip install -e .[test]
          python -m pip install pytest-xdist

      - name: Run Quickcheck
        # -n auto: pytest-xdist spreads tests across all available CPU cores
        run: python -m pytest -q tests/test_model_quickcheck.py -n auto
28 changes: 17 additions & 11 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,12 @@
import onnx
import torch

from QEfficient.base.onnx_transforms import BaseOnnxTransform, OnnxTransformPipeline
from QEfficient.base.onnx_transforms import (
BaseOnnxTransform,
FP16ClipTransform,
OnnxTransformPipeline,
SplitTensorsTransform,
)
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
from QEfficient.generation.cloud_infer import QAICInferenceSession
Expand Down Expand Up @@ -49,9 +54,8 @@ class QEFFBaseModel(ABC):
_pytorch_transforms: List[PytorchTransform]
_onnx_transforms = [BaseOnnxTransform]

@classmethod
def _transform_names(cls) -> List[str]:
return [x.__name__ for x in cls._pytorch_transforms + cls._onnx_transforms]
def _transform_names(self) -> List[str]:
    """Return the class names of all PyTorch and ONNX transforms for this model.

    Instance method (not a classmethod) so that per-instance overrides of
    ``_pytorch_transforms`` / ``_onnx_transforms`` are reflected in the result.
    """
    return [x.__name__ for x in self._pytorch_transforms + self._onnx_transforms]

def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
Expand Down Expand Up @@ -242,9 +246,7 @@ def _export(
# check if the model is in meta state or weights are offloaded
self._model_offloaded_check()

# Export directly into export_dir so any external data files are retained.
export_dir.mkdir(parents=True, exist_ok=True)
tmp_onnx_path = onnx_path

# Create input_names from example_inputs
input_names = []
Expand Down Expand Up @@ -274,7 +276,7 @@ def _export(
torch.onnx.export(
self.model,
(example_inputs,),
str(tmp_onnx_path),
str(onnx_path),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
Expand All @@ -283,11 +285,13 @@ def _export(
)
logger.info("PyTorch export successful")
_ = self._offload_model_weights(offload_pt_weights)
model = onnx.load(tmp_onnx_path, load_external_data=False)
model = onnx.load(onnx_path, load_external_data=False)

# Decide whether any requested transform needs external tensor data loaded from disk
needs_external_tensor_data = any(
transform in self._onnx_transforms for transform in (FP16ClipTransform, SplitTensorsTransform)
)
transform_kwargs = {
"onnx_base_dir": str(export_dir),
"onnx_base_dir": str(export_dir) if needs_external_tensor_data else None,
"model_name": self.model_name,
}
if onnx_transform_kwargs is not None:
Expand All @@ -302,7 +306,9 @@ def _export(
)
logger.info("ONNX transforms applied")

onnx.save(model, onnx_path)
onnx_path_tmp = onnx_path.with_suffix(onnx_path.suffix + ".tmp")
onnx.save(model, onnx_path_tmp)
onnx_path_tmp.replace(onnx_path)
del model
gc.collect()
logger.info("Transformed ONNX saved")
Expand Down
19 changes: 14 additions & 5 deletions QEfficient/base/onnx_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

import logging
import os
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple, Type

Expand Down Expand Up @@ -106,16 +105,27 @@ class CustomOpTransform(BaseOnnxTransform):
@classmethod
def apply(cls, model: ModelProto) -> bool:
    """Register custom-op symbolics with the PyTorch exporter and embed the
    ONNX function prototypes that *model* actually uses.

    :param model: ONNX model to transform in place.
    :returns: ``True`` if at least one function prototype was appended.
    """
    op_applied = False

    # Register with PyTorch ONNX exporter (for export time).
    for op_name, (func_class, _) in cls._custom_ops.items():
        if hasattr(func_class, "symbolic"):
            torch.onnx.register_custom_op_symbolic(f"::{op_name}", func_class.symbolic, ONNX_EXPORT_OPSET)

    # Collect every op type referenced by the main graph and by already-present
    # functions, so we only embed prototypes the model actually calls.
    used_op_types = {node.op_type for node in model.graph.node}
    for function_proto in model.functions:
        used_op_types.update(node.op_type for node in function_proto.node)

    # Add function prototypes to the model, skipping unused ops and duplicates.
    # The first tuple element (the autograd Function class) is not needed here.
    existing = {f.name for f in model.functions}
    for _, onnxscript_func in cls._custom_ops.values():
        proto = onnxscript_func.to_function_proto()
        if proto.name not in used_op_types:
            continue
        if proto.name not in existing:
            model.functions.append(proto)
            op_applied = True

    return op_applied


Expand Down Expand Up @@ -202,8 +212,6 @@ class OnnxTransformPipeline(BaseOnnxTransform):
"""Pipeline to apply multiple ONNX transformations in sequence."""

def __init__(self, transforms: List[Type[BaseOnnxTransform]]):
    """Store the ordered list of transform classes to apply.

    :param transforms: Transform classes, applied in order by :meth:`apply`.
        An empty list is valid and yields a no-op pipeline.
    """
    self.transforms = transforms

def apply(
Expand All @@ -228,7 +236,8 @@ def apply(
do_split = SplitTensorsTransform in requested
fp16_min, fp16_max = np.finfo(np.float16).min, np.finfo(np.float16).max
file_num_tracker = {"num": 0, "size": 0}
external_data_helper.load_external_data_for_model(model, onnx_base_dir)
if onnx_base_dir is not None:
external_data_helper.load_external_data_for_model(model, onnx_base_dir)

if do_fp16 or do_split:
for tensor in external_data_helper._get_all_tensors(model):
Expand Down
31 changes: 31 additions & 0 deletions tests/test_model_quickcheck.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,37 @@ def test_causal_subfunction_export_smoke(tmp_path):
assert not any("QEffGPT2Block" in name for name in without_names)


@pytest.mark.llm_model
@pytest.mark.parametrize(
    ("model_type", "model_id"),
    sorted(CAUSAL_RUNTIME_MODEL_IDS.items()),
    ids=sorted(CAUSAL_RUNTIME_MODEL_IDS),
)
def test_causal_compile_with_subfunctions_all_models(model_type, model_id, tmp_path):
    """Smoke-test that each runtime causal-LM model compiles with ONNX subfunctions enabled."""
    # model_type only disambiguates the parametrize IDs; it is unused in the body.
    del model_type
    try:
        qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True)
    except Exception as exc:
        # Network/hub failures become a skip rather than a test failure.
        _skip_on_model_fetch_error(exc, model_id)

    try:
        qpc = qeff_model.compile(
            prefill_seq_len=8,
            ctx_len=32,
            use_onnx_subfunctions=True,
            compile_dir=tmp_path / "compile-with-subfunctions",
        )
    except Exception as exc:
        # Compilation needs backend tooling that CI may lack; treat any failure
        # here as "environment unsupported" and skip with the original error.
        pytest.skip(
            f"Skipping compile for {model_id}: compile backend unavailable or unsupported in this environment "
            f"({type(exc).__name__}: {exc})"
        )

    # compile() is expected to return the path to the generated "qpc" directory.
    qpc_path = Path(qpc)
    assert qpc_path.name == "qpc"
    assert qpc_path.is_dir()


@pytest.mark.llm_model
@pytest.mark.parametrize(
("model_type", "model_id"),
Expand Down
Loading