diff --git a/QEfficient/cloud/infer.py b/QEfficient/cloud/infer.py index d17ca26ff..4ebc37ec2 100644 --- a/QEfficient/cloud/infer.py +++ b/QEfficient/cloud/infer.py @@ -139,6 +139,7 @@ def main( qnn_config: Optional[str] = None, trust_remote_code: Optional[bool] = False, ccl_enabled: Optional[bool] = False, + use_onnx_subfunctions: bool = False, **kwargs, ) -> None: """ @@ -205,6 +206,8 @@ def main( Path of the QNN Config parameters file. Default is None. trust_remote_code : bool, optional If True, trusts remote code when loading models from HuggingFace. Default is False. + use_onnx_subfunctions : bool, optional + Enables ONNX subfunctions during export and compile. Default is False. **kwargs : Additional compiler options passed directly to `qaic-compile`. Any flag supported by `qaic-compile` can be passed. Parameters are converted to flags as follows: @@ -231,12 +234,14 @@ def main( """ cache_dir = check_and_assign_cache_dir(local_model_dir, cache_dir) - if "--mxfp6" in sys.argv: - if args.mxfp6: - logger.warning("mxfp6 is going to be deprecated in a future release, use -mxfp6_matmul instead.") - if "--mxint8" in sys.argv: - if args.mxint8: - logger.warning("mxint8 is going to be deprecated in a future release, use -mxint8_kv_cache instead.") + if "--mxfp6" in sys.argv and mxfp6: + logger.warning("mxfp6 is going to be deprecated in a future release, use -mxfp6_matmul instead.") + if "--mxint8" in sys.argv and mxint8: + logger.warning("mxint8 is going to be deprecated in a future release, use -mxint8_kv_cache instead.") + + qaic_config = {"ccl_enabled": True} if ccl_enabled else None + + qaic_config = {"ccl_enabled": True} if ccl_enabled else None qaic_config = {"ccl_enabled": True} if ccl_enabled else None @@ -280,6 +285,7 @@ def main( allow_mxint8_mdp_io=allow_mxint8_mdp_io, enable_qnn=enable_qnn, qnn_config=qnn_config, + use_onnx_subfunctions=use_onnx_subfunctions, **kwargs, ) @@ -382,6 +388,14 @@ def main( action="store_true", help="Compress Present/Past KV to MXINT8 using CustomIO config, default is False", ) + parser.add_argument( + "--use-onnx-subfunctions", + "--use_onnx_subfunctions", + dest="use_onnx_subfunctions", + action="store_true", + default=False, + help="Enable ONNX subfunctions during export/compile.", + ) parser.add_argument( "--num_cores", "--num-cores", type=int, required=True, help="Number of cores to compile on Cloud AI 100" ) diff --git a/QEfficient/proxy/__init__.py b/QEfficient/proxy/__init__.py new file mode 100644 index 000000000..410b674e5 --- /dev/null +++ b/QEfficient/proxy/__init__.py @@ -0,0 +1,13 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +from QEfficient.proxy.proxy_transform import QeffProxyEmbedding, QeffProxyLinear + +__all__ = [ + "QeffProxyEmbedding", + "QeffProxyLinear", +] diff --git a/QEfficient/proxy/proxy_transform.py b/QEfficient/proxy/proxy_transform.py new file mode 100644 index 000000000..ec6af7d81 --- /dev/null +++ b/QEfficient/proxy/proxy_transform.py @@ -0,0 +1,27 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +import torch +from torch import nn + + +class QeffProxyEmbedding(nn.Module): + def __init__(self, num_embeddings, embedding_dim): + self.embed_tokens = None + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + + def forward(self, hidden_states, past_key_values_length=None): + inputs_embeds = torch.unsqueeze(hidden_states.float(), 2).expand(-1, -1, self.embedding_dim) + return inputs_embeds + + +class QeffProxyLinear(nn.Module): + def __init__(self, in_features, out_features, bias=False): + self.lm_head = None + + def forward(self, hidden_states): + return hidden_states diff --git a/QEfficient/proxy/pytorch_transform.py b/QEfficient/proxy/pytorch_transform.py new file mode 100644 index 000000000..ce68474cd --- /dev/null +++ b/QEfficient/proxy/pytorch_transform.py @@ -0,0 +1,22 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +import torch.nn as nn + +from QEfficient.base.pytorch_transforms import ProxyModuleMappingTransform +from QEfficient.proxy import QeffProxyEmbedding, QeffProxyLinear + + +class QeffProxyModuleTransform(ProxyModuleMappingTransform): + """ + This transform is used to replace the original modules with QEfficient modules. + """ + + _module_mapping = { + nn.Embedding: QeffProxyEmbedding, + nn.Linear: QeffProxyLinear, + } diff --git a/QEfficient/transformers/models/gemma3/configs/gemma_updated_npi.yaml b/QEfficient/transformers/models/gemma3/configs/gemma_updated_npi.yaml new file mode 100644 index 000000000..faf4f9d72 --- /dev/null +++ b/QEfficient/transformers/models/gemma3/configs/gemma_updated_npi.yaml @@ -0,0 +1,1564 @@ +FP16NodeInstanceNames: + - /lm_head/MatMul_output_0 + - onnx::MatMul_25530 + +FP32NodeInstanceNames: + + + #Mul + - /language_model/layers.0/mlp/act_fn/Mul_output_0 + - /language_model/layers.0/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.0/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.0/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.0/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.0/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.1/mlp/act_fn/Mul_output_0 + - /language_model/layers.1/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.1/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.1/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.1/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.1/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.2/mlp/act_fn/Mul_output_0 + - /language_model/layers.2/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.2/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.2/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.2/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.2/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.3/mlp/act_fn/Mul_output_0 + - /language_model/layers.3/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.3/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.3/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.3/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.3/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.4/mlp/act_fn/Mul_output_0 + - /language_model/layers.4/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.4/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.4/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.4/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.4/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.5/mlp/act_fn/Mul_output_0 + - /language_model/layers.5/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.5/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.5/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.5/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.5/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.6/mlp/act_fn/Mul_output_0 + - /language_model/layers.6/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.6/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.6/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.6/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.6/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.7/mlp/act_fn/Mul_output_0 + - /language_model/layers.7/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.7/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.7/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.7/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.7/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.8/mlp/act_fn/Mul_output_0 + - /language_model/layers.8/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.8/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.8/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.8/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.8/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.9/mlp/act_fn/Mul_output_0 + - /language_model/layers.9/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.9/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.9/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.9/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.9/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.10/mlp/act_fn/Mul_output_0 + - /language_model/layers.10/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.10/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.10/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.10/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.10/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.11/mlp/act_fn/Mul_output_0 + - /language_model/layers.11/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.11/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.11/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.11/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.11/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.12/mlp/act_fn/Mul_output_0 + - /language_model/layers.12/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.12/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.12/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.12/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.12/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.13/mlp/act_fn/Mul_output_0 + - /language_model/layers.13/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.13/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.13/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.13/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.13/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.14/mlp/act_fn/Mul_output_0 + - /language_model/layers.14/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.14/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.14/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.14/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.14/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.15/mlp/act_fn/Mul_output_0 + - /language_model/layers.15/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.15/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.15/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.15/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.15/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.16/mlp/act_fn/Mul_output_0 + - /language_model/layers.16/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.16/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.16/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.16/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.16/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.17/mlp/act_fn/Mul_output_0 + - /language_model/layers.17/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.17/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.17/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.17/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.17/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.18/mlp/act_fn/Mul_output_0 + - /language_model/layers.18/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.18/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.18/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.18/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.18/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.19/mlp/act_fn/Mul_output_0 + - /language_model/layers.19/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.19/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.19/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.19/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.19/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.20/mlp/act_fn/Mul_output_0 + - /language_model/layers.20/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.20/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.20/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.20/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.20/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.21/mlp/act_fn/Mul_output_0 + - /language_model/layers.21/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.21/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.21/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.21/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.21/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.22/mlp/act_fn/Mul_output_0 + - /language_model/layers.22/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.22/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.22/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.22/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.22/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.23/mlp/act_fn/Mul_output_0 + - /language_model/layers.23/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.23/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.23/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.23/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.23/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.24/mlp/act_fn/Mul_output_0 + - /language_model/layers.24/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.24/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.24/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.24/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.24/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.25/mlp/act_fn/Mul_output_0 + - /language_model/layers.25/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.25/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.25/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.25/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.25/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.26/mlp/act_fn/Mul_output_0 + - /language_model/layers.26/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.26/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.26/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.26/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.26/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.27/mlp/act_fn/Mul_output_0 + - /language_model/layers.27/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.27/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.27/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.27/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.27/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.28/mlp/act_fn/Mul_output_0 + - /language_model/layers.28/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.28/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.28/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.28/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.28/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.29/mlp/act_fn/Mul_output_0 + - /language_model/layers.29/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.29/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.29/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.29/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.29/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.30/mlp/act_fn/Mul_output_0 + - /language_model/layers.30/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.30/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.30/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.30/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.30/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.31/mlp/act_fn/Mul_output_0 + - /language_model/layers.31/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.31/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.31/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.31/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.31/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.32/mlp/act_fn/Mul_output_0 + - /language_model/layers.32/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.32/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.32/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.32/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.32/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.33/mlp/act_fn/Mul_output_0 + - /language_model/layers.33/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.33/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.33/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.33/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.33/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.34/mlp/act_fn/Mul_output_0 + - /language_model/layers.34/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.34/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.34/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.34/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.34/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.35/mlp/act_fn/Mul_output_0 + - /language_model/layers.35/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.35/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.35/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.35/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.35/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.36/mlp/act_fn/Mul_output_0 + - /language_model/layers.36/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.36/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.36/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.36/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.36/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.37/mlp/act_fn/Mul_output_0 + - /language_model/layers.37/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.37/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.37/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.37/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.37/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.38/mlp/act_fn/Mul_output_0 + - /language_model/layers.38/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.38/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.38/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.38/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.38/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.39/mlp/act_fn/Mul_output_0 + - /language_model/layers.39/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.39/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.39/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.39/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.39/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.40/mlp/act_fn/Mul_output_0 + - /language_model/layers.40/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.40/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.40/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.40/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.40/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.41/mlp/act_fn/Mul_output_0 + - /language_model/layers.41/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.41/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.41/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.41/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.41/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.42/mlp/act_fn/Mul_output_0 + - /language_model/layers.42/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.42/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.42/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.42/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.42/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.43/mlp/act_fn/Mul_output_0 + - /language_model/layers.43/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.43/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.43/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.43/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.43/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.44/mlp/act_fn/Mul_output_0 + - /language_model/layers.44/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.44/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.44/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.44/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.44/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.45/mlp/act_fn/Mul_output_0 + - /language_model/layers.45/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.45/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.45/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.45/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.45/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.46/mlp/act_fn/Mul_output_0 + - /language_model/layers.46/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.46/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.46/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.46/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.46/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.47/mlp/act_fn/Mul_output_0 + - /language_model/layers.47/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.47/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.47/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.47/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.47/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.48/mlp/act_fn/Mul_output_0 + - /language_model/layers.48/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.48/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.48/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.48/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.48/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.49/mlp/act_fn/Mul_output_0 + - /language_model/layers.49/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.49/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.49/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.49/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.49/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.50/mlp/act_fn/Mul_output_0 + - /language_model/layers.50/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.50/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.50/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.50/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.50/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.51/mlp/act_fn/Mul_output_0 + - /language_model/layers.51/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.51/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.51/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.51/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.51/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.52/mlp/act_fn/Mul_output_0 + - /language_model/layers.52/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.52/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.52/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.52/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.52/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.53/mlp/act_fn/Mul_output_0 + - /language_model/layers.53/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.53/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.53/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.53/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.53/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.54/mlp/act_fn/Mul_output_0 + - /language_model/layers.54/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.54/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.54/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.54/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.54/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.55/mlp/act_fn/Mul_output_0 + - /language_model/layers.55/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.55/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.55/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.55/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.55/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.56/mlp/act_fn/Mul_output_0 + - /language_model/layers.56/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.56/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.56/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.56/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.56/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.57/mlp/act_fn/Mul_output_0 + - /language_model/layers.57/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.57/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.57/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.57/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.57/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.58/mlp/act_fn/Mul_output_0 + - /language_model/layers.58/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.58/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.58/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.58/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.58/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.59/mlp/act_fn/Mul_output_0 + - /language_model/layers.59/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.59/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.59/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.59/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.59/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.60/mlp/act_fn/Mul_output_0 + - /language_model/layers.60/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.60/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.60/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.60/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.60/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.61/mlp/act_fn/Mul_output_0 + - /language_model/layers.61/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.61/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.61/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.61/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.61/mlp/act_fn/Mul_5_output_0 + + #Constant + - /language_model/layers.0/mlp/act_fn/Constant_output_0 + - /language_model/layers.0/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.0/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.0/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.1/mlp/act_fn/Constant_output_0 + - /language_model/layers.1/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.1/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.1/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.2/mlp/act_fn/Constant_output_0 + - /language_model/layers.2/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.2/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.2/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.3/mlp/act_fn/Constant_output_0 + - /language_model/layers.3/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.3/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.3/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.4/mlp/act_fn/Constant_output_0 + - /language_model/layers.4/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.4/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.4/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.5/mlp/act_fn/Constant_output_0 + - /language_model/layers.5/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.5/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.5/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.6/mlp/act_fn/Constant_output_0 + - /language_model/layers.6/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.6/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.6/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.7/mlp/act_fn/Constant_output_0 + - /language_model/layers.7/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.7/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.7/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.8/mlp/act_fn/Constant_output_0 + - /language_model/layers.8/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.8/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.8/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.9/mlp/act_fn/Constant_output_0 + - /language_model/layers.9/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.9/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.9/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.10/mlp/act_fn/Constant_output_0 + - /language_model/layers.10/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.10/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.10/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.11/mlp/act_fn/Constant_output_0 + - /language_model/layers.11/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.11/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.11/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.12/mlp/act_fn/Constant_output_0 + - /language_model/layers.12/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.12/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.12/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.13/mlp/act_fn/Constant_output_0 + - /language_model/layers.13/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.13/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.13/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.14/mlp/act_fn/Constant_output_0 + - /language_model/layers.14/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.14/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.14/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.15/mlp/act_fn/Constant_output_0 + - /language_model/layers.15/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.15/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.15/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.16/mlp/act_fn/Constant_output_0 + - /language_model/layers.16/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.16/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.16/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.17/mlp/act_fn/Constant_output_0 + - /language_model/layers.17/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.17/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.17/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.18/mlp/act_fn/Constant_output_0 + - /language_model/layers.18/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.18/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.18/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.19/mlp/act_fn/Constant_output_0 + - /language_model/layers.19/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.19/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.19/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.20/mlp/act_fn/Constant_output_0 + - /language_model/layers.20/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.20/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.20/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.21/mlp/act_fn/Constant_output_0 + - /language_model/layers.21/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.21/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.21/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.22/mlp/act_fn/Constant_output_0 + - /language_model/layers.22/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.22/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.22/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.23/mlp/act_fn/Constant_output_0 + - /language_model/layers.23/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.23/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.23/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.24/mlp/act_fn/Constant_output_0 + - /language_model/layers.24/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.24/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.24/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.25/mlp/act_fn/Constant_output_0 + - /language_model/layers.25/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.25/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.25/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.26/mlp/act_fn/Constant_output_0 + - /language_model/layers.26/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.26/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.26/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.27/mlp/act_fn/Constant_output_0 + - /language_model/layers.27/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.27/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.27/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.28/mlp/act_fn/Constant_output_0 + - /language_model/layers.28/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.28/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.28/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.29/mlp/act_fn/Constant_output_0 + - /language_model/layers.29/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.29/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.29/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.30/mlp/act_fn/Constant_output_0 + - /language_model/layers.30/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.30/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.30/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.31/mlp/act_fn/Constant_output_0 + - /language_model/layers.31/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.31/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.31/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.32/mlp/act_fn/Constant_output_0 + - /language_model/layers.32/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.32/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.32/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.33/mlp/act_fn/Constant_output_0 + - /language_model/layers.33/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.33/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.33/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.34/mlp/act_fn/Constant_output_0 + - /language_model/layers.34/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.34/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.34/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.35/mlp/act_fn/Constant_output_0 + - /language_model/layers.35/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.35/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.35/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.36/mlp/act_fn/Constant_output_0 + - /language_model/layers.36/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.36/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.36/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.37/mlp/act_fn/Constant_output_0 + - /language_model/layers.37/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.37/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.37/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.38/mlp/act_fn/Constant_output_0 + - /language_model/layers.38/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.38/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.38/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.39/mlp/act_fn/Constant_output_0 + - /language_model/layers.39/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.39/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.39/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.40/mlp/act_fn/Constant_output_0 + - /language_model/layers.40/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.40/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.40/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.41/mlp/act_fn/Constant_output_0 + - /language_model/layers.41/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.41/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.41/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.42/mlp/act_fn/Constant_output_0 + - /language_model/layers.42/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.42/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.42/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.43/mlp/act_fn/Constant_output_0 + - /language_model/layers.43/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.43/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.43/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.44/mlp/act_fn/Constant_output_0 + - /language_model/layers.44/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.44/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.44/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.45/mlp/act_fn/Constant_output_0 + - /language_model/layers.45/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.45/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.45/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.46/mlp/act_fn/Constant_output_0 + - /language_model/layers.46/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.46/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.46/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.47/mlp/act_fn/Constant_output_0 + - /language_model/layers.47/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.47/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.47/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.48/mlp/act_fn/Constant_output_0 + - /language_model/layers.48/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.48/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.48/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.49/mlp/act_fn/Constant_output_0 + - /language_model/layers.49/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.49/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.49/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.50/mlp/act_fn/Constant_output_0 + - /language_model/layers.50/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.50/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.50/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.51/mlp/act_fn/Constant_output_0 + - /language_model/layers.51/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.51/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.51/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.52/mlp/act_fn/Constant_output_0 + - /language_model/layers.52/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.52/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.52/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.53/mlp/act_fn/Constant_output_0 + - /language_model/layers.53/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.53/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.53/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.54/mlp/act_fn/Constant_output_0 + - /language_model/layers.54/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.54/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.54/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.55/mlp/act_fn/Constant_output_0 + - /language_model/layers.55/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.55/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.55/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.56/mlp/act_fn/Constant_output_0 + - /language_model/layers.56/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.56/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.56/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.57/mlp/act_fn/Constant_output_0 + - /language_model/layers.57/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.57/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.57/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.58/mlp/act_fn/Constant_output_0 + - /language_model/layers.58/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.58/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.58/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.59/mlp/act_fn/Constant_output_0 + - /language_model/layers.59/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.59/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.59/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.60/mlp/act_fn/Constant_output_0 + - /language_model/layers.60/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.60/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.60/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.61/mlp/act_fn/Constant_output_0 + - /language_model/layers.61/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.61/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.61/mlp/act_fn/Constant_3_output_0 + + #Add + - /language_model/layers.0/mlp/act_fn/Add_output_0 + - /language_model/layers.0/mlp/act_fn/Add_1_output_0 + - /language_model/layers.1/mlp/act_fn/Add_output_0 + - /language_model/layers.1/mlp/act_fn/Add_1_output_0 + - /language_model/layers.2/mlp/act_fn/Add_output_0 + - /language_model/layers.2/mlp/act_fn/Add_1_output_0 + - /language_model/layers.3/mlp/act_fn/Add_output_0 + - /language_model/layers.3/mlp/act_fn/Add_1_output_0 + - /language_model/layers.4/mlp/act_fn/Add_output_0 + - /language_model/layers.4/mlp/act_fn/Add_1_output_0 + - /language_model/layers.5/mlp/act_fn/Add_output_0 + - /language_model/layers.5/mlp/act_fn/Add_1_output_0 + - /language_model/layers.6/mlp/act_fn/Add_output_0 + - /language_model/layers.6/mlp/act_fn/Add_1_output_0 + - /language_model/layers.7/mlp/act_fn/Add_output_0 + - /language_model/layers.7/mlp/act_fn/Add_1_output_0 + - /language_model/layers.8/mlp/act_fn/Add_output_0 + - /language_model/layers.8/mlp/act_fn/Add_1_output_0 + - /language_model/layers.9/mlp/act_fn/Add_output_0 + - /language_model/layers.9/mlp/act_fn/Add_1_output_0 + - /language_model/layers.10/mlp/act_fn/Add_output_0 + - /language_model/layers.10/mlp/act_fn/Add_1_output_0 + - /language_model/layers.11/mlp/act_fn/Add_output_0 + - /language_model/layers.11/mlp/act_fn/Add_1_output_0 + - /language_model/layers.12/mlp/act_fn/Add_output_0 + - /language_model/layers.12/mlp/act_fn/Add_1_output_0 + - /language_model/layers.13/mlp/act_fn/Add_output_0 + - /language_model/layers.13/mlp/act_fn/Add_1_output_0 + - /language_model/layers.14/mlp/act_fn/Add_output_0 + - /language_model/layers.14/mlp/act_fn/Add_1_output_0 + - /language_model/layers.15/mlp/act_fn/Add_output_0 + - /language_model/layers.15/mlp/act_fn/Add_1_output_0 + - /language_model/layers.16/mlp/act_fn/Add_output_0 + - /language_model/layers.16/mlp/act_fn/Add_1_output_0 + - /language_model/layers.17/mlp/act_fn/Add_output_0 + - /language_model/layers.17/mlp/act_fn/Add_1_output_0 + - /language_model/layers.18/mlp/act_fn/Add_output_0 + - /language_model/layers.18/mlp/act_fn/Add_1_output_0 + - /language_model/layers.19/mlp/act_fn/Add_output_0 + - /language_model/layers.19/mlp/act_fn/Add_1_output_0 + - /language_model/layers.20/mlp/act_fn/Add_output_0 + - /language_model/layers.20/mlp/act_fn/Add_1_output_0 + - /language_model/layers.21/mlp/act_fn/Add_output_0 + - /language_model/layers.21/mlp/act_fn/Add_1_output_0 + - /language_model/layers.22/mlp/act_fn/Add_output_0 + - /language_model/layers.22/mlp/act_fn/Add_1_output_0 + - /language_model/layers.23/mlp/act_fn/Add_output_0 + - /language_model/layers.23/mlp/act_fn/Add_1_output_0 + - /language_model/layers.24/mlp/act_fn/Add_output_0 + - /language_model/layers.24/mlp/act_fn/Add_1_output_0 + - /language_model/layers.25/mlp/act_fn/Add_output_0 + - /language_model/layers.25/mlp/act_fn/Add_1_output_0 + - /language_model/layers.26/mlp/act_fn/Add_output_0 + - /language_model/layers.26/mlp/act_fn/Add_1_output_0 + - /language_model/layers.27/mlp/act_fn/Add_output_0 + - /language_model/layers.27/mlp/act_fn/Add_1_output_0 + - /language_model/layers.28/mlp/act_fn/Add_output_0 + - /language_model/layers.28/mlp/act_fn/Add_1_output_0 + - /language_model/layers.29/mlp/act_fn/Add_output_0 + - /language_model/layers.29/mlp/act_fn/Add_1_output_0 + - /language_model/layers.30/mlp/act_fn/Add_output_0 + - /language_model/layers.30/mlp/act_fn/Add_1_output_0 + - /language_model/layers.31/mlp/act_fn/Add_output_0 + - /language_model/layers.31/mlp/act_fn/Add_1_output_0 + - /language_model/layers.32/mlp/act_fn/Add_output_0 + - /language_model/layers.32/mlp/act_fn/Add_1_output_0 + - /language_model/layers.33/mlp/act_fn/Add_output_0 + - /language_model/layers.33/mlp/act_fn/Add_1_output_0 + - /language_model/layers.34/mlp/act_fn/Add_output_0 + - /language_model/layers.34/mlp/act_fn/Add_1_output_0 + - /language_model/layers.35/mlp/act_fn/Add_output_0 + - /language_model/layers.35/mlp/act_fn/Add_1_output_0 + - /language_model/layers.36/mlp/act_fn/Add_output_0 + - /language_model/layers.36/mlp/act_fn/Add_1_output_0 + - /language_model/layers.37/mlp/act_fn/Add_output_0 + - /language_model/layers.37/mlp/act_fn/Add_1_output_0 + - /language_model/layers.38/mlp/act_fn/Add_output_0 + - /language_model/layers.38/mlp/act_fn/Add_1_output_0 + - /language_model/layers.39/mlp/act_fn/Add_output_0 + - /language_model/layers.39/mlp/act_fn/Add_1_output_0 + - /language_model/layers.40/mlp/act_fn/Add_output_0 + - /language_model/layers.40/mlp/act_fn/Add_1_output_0 + - /language_model/layers.41/mlp/act_fn/Add_output_0 + - /language_model/layers.41/mlp/act_fn/Add_1_output_0 + - /language_model/layers.42/mlp/act_fn/Add_output_0 + - /language_model/layers.42/mlp/act_fn/Add_1_output_0 + - /language_model/layers.43/mlp/act_fn/Add_output_0 + - /language_model/layers.43/mlp/act_fn/Add_1_output_0 + - /language_model/layers.44/mlp/act_fn/Add_output_0 + - /language_model/layers.44/mlp/act_fn/Add_1_output_0 + - /language_model/layers.45/mlp/act_fn/Add_output_0 + - /language_model/layers.45/mlp/act_fn/Add_1_output_0 + - /language_model/layers.46/mlp/act_fn/Add_output_0 + - /language_model/layers.46/mlp/act_fn/Add_1_output_0 + - /language_model/layers.47/mlp/act_fn/Add_output_0 + - /language_model/layers.47/mlp/act_fn/Add_1_output_0 + - /language_model/layers.48/mlp/act_fn/Add_output_0 + - /language_model/layers.48/mlp/act_fn/Add_1_output_0 + - /language_model/layers.49/mlp/act_fn/Add_output_0 + - /language_model/layers.49/mlp/act_fn/Add_1_output_0 + - /language_model/layers.50/mlp/act_fn/Add_output_0 + - /language_model/layers.50/mlp/act_fn/Add_1_output_0 + - /language_model/layers.51/mlp/act_fn/Add_output_0 + - /language_model/layers.51/mlp/act_fn/Add_1_output_0 + - /language_model/layers.52/mlp/act_fn/Add_output_0 + - /language_model/layers.52/mlp/act_fn/Add_1_output_0 + - /language_model/layers.53/mlp/act_fn/Add_output_0 + - /language_model/layers.53/mlp/act_fn/Add_1_output_0 + - /language_model/layers.54/mlp/act_fn/Add_output_0 + - /language_model/layers.54/mlp/act_fn/Add_1_output_0 + - /language_model/layers.55/mlp/act_fn/Add_output_0 + - /language_model/layers.55/mlp/act_fn/Add_1_output_0 + - /language_model/layers.56/mlp/act_fn/Add_output_0 + - /language_model/layers.56/mlp/act_fn/Add_1_output_0 + - /language_model/layers.57/mlp/act_fn/Add_output_0 + - /language_model/layers.57/mlp/act_fn/Add_1_output_0 + - /language_model/layers.58/mlp/act_fn/Add_output_0 + - /language_model/layers.58/mlp/act_fn/Add_1_output_0 + - /language_model/layers.59/mlp/act_fn/Add_output_0 + - /language_model/layers.59/mlp/act_fn/Add_1_output_0 + - /language_model/layers.60/mlp/act_fn/Add_output_0 + - /language_model/layers.60/mlp/act_fn/Add_1_output_0 + - /language_model/layers.61/mlp/act_fn/Add_output_0 + - /language_model/layers.61/mlp/act_fn/Add_1_output_0 + + #Tanh + - /language_model/layers.0/mlp/act_fn/Tanh_output_0 + - /language_model/layers.1/mlp/act_fn/Tanh_output_0 + - /language_model/layers.2/mlp/act_fn/Tanh_output_0 + - /language_model/layers.3/mlp/act_fn/Tanh_output_0 + - /language_model/layers.4/mlp/act_fn/Tanh_output_0 + - /language_model/layers.5/mlp/act_fn/Tanh_output_0 + - /language_model/layers.6/mlp/act_fn/Tanh_output_0 + - /language_model/layers.7/mlp/act_fn/Tanh_output_0 + - /language_model/layers.8/mlp/act_fn/Tanh_output_0 + - /language_model/layers.9/mlp/act_fn/Tanh_output_0 + - /language_model/layers.10/mlp/act_fn/Tanh_output_0 + - /language_model/layers.11/mlp/act_fn/Tanh_output_0 + - /language_model/layers.12/mlp/act_fn/Tanh_output_0 + - /language_model/layers.13/mlp/act_fn/Tanh_output_0 + - /language_model/layers.14/mlp/act_fn/Tanh_output_0 + - /language_model/layers.15/mlp/act_fn/Tanh_output_0 + - /language_model/layers.16/mlp/act_fn/Tanh_output_0 + - /language_model/layers.17/mlp/act_fn/Tanh_output_0 + - /language_model/layers.18/mlp/act_fn/Tanh_output_0 + - /language_model/layers.19/mlp/act_fn/Tanh_output_0 + - /language_model/layers.20/mlp/act_fn/Tanh_output_0 + - /language_model/layers.21/mlp/act_fn/Tanh_output_0 + - /language_model/layers.22/mlp/act_fn/Tanh_output_0 + - /language_model/layers.23/mlp/act_fn/Tanh_output_0 + - /language_model/layers.24/mlp/act_fn/Tanh_output_0 + - /language_model/layers.25/mlp/act_fn/Tanh_output_0 + - /language_model/layers.26/mlp/act_fn/Tanh_output_0 + - /language_model/layers.27/mlp/act_fn/Tanh_output_0 + - /language_model/layers.28/mlp/act_fn/Tanh_output_0 + - /language_model/layers.29/mlp/act_fn/Tanh_output_0 + - /language_model/layers.30/mlp/act_fn/Tanh_output_0 + - /language_model/layers.31/mlp/act_fn/Tanh_output_0 + - /language_model/layers.32/mlp/act_fn/Tanh_output_0 + - /language_model/layers.33/mlp/act_fn/Tanh_output_0 + - /language_model/layers.34/mlp/act_fn/Tanh_output_0 + - /language_model/layers.35/mlp/act_fn/Tanh_output_0 + - /language_model/layers.36/mlp/act_fn/Tanh_output_0 + - /language_model/layers.37/mlp/act_fn/Tanh_output_0 + - /language_model/layers.38/mlp/act_fn/Tanh_output_0 + - /language_model/layers.39/mlp/act_fn/Tanh_output_0 + - /language_model/layers.40/mlp/act_fn/Tanh_output_0 + - /language_model/layers.41/mlp/act_fn/Tanh_output_0 + - /language_model/layers.42/mlp/act_fn/Tanh_output_0 + - /language_model/layers.43/mlp/act_fn/Tanh_output_0 + - /language_model/layers.44/mlp/act_fn/Tanh_output_0 + - /language_model/layers.45/mlp/act_fn/Tanh_output_0 + - /language_model/layers.46/mlp/act_fn/Tanh_output_0 + - /language_model/layers.47/mlp/act_fn/Tanh_output_0 + - /language_model/layers.48/mlp/act_fn/Tanh_output_0 + - /language_model/layers.49/mlp/act_fn/Tanh_output_0 + - /language_model/layers.50/mlp/act_fn/Tanh_output_0 + - /language_model/layers.51/mlp/act_fn/Tanh_output_0 + - /language_model/layers.52/mlp/act_fn/Tanh_output_0 + - /language_model/layers.53/mlp/act_fn/Tanh_output_0 + - /language_model/layers.54/mlp/act_fn/Tanh_output_0 + - /language_model/layers.55/mlp/act_fn/Tanh_output_0 + - /language_model/layers.56/mlp/act_fn/Tanh_output_0 + - /language_model/layers.57/mlp/act_fn/Tanh_output_0 + - /language_model/layers.58/mlp/act_fn/Tanh_output_0 + - /language_model/layers.59/mlp/act_fn/Tanh_output_0 + - /language_model/layers.60/mlp/act_fn/Tanh_output_0 + - /language_model/layers.61/mlp/act_fn/Tanh_output_0 + - /language_model/layers.0/mlp/Mul_output_0 + - /language_model/layers.1/mlp/Mul_output_0 + - /language_model/layers.2/mlp/Mul_output_0 + - /language_model/layers.3/mlp/Mul_output_0 + - /language_model/layers.4/mlp/Mul_output_0 + - /language_model/layers.5/mlp/Mul_output_0 + - /language_model/layers.6/mlp/Mul_output_0 + - /language_model/layers.7/mlp/Mul_output_0 + - /language_model/layers.8/mlp/Mul_output_0 + - /language_model/layers.9/mlp/Mul_output_0 + - /language_model/layers.10/mlp/Mul_output_0 + - /language_model/layers.11/mlp/Mul_output_0 + - /language_model/layers.12/mlp/Mul_output_0 + - /language_model/layers.13/mlp/Mul_output_0 + - /language_model/layers.14/mlp/Mul_output_0 + - /language_model/layers.15/mlp/Mul_output_0 + - /language_model/layers.16/mlp/Mul_output_0 + - /language_model/layers.17/mlp/Mul_output_0 + - /language_model/layers.18/mlp/Mul_output_0 + - /language_model/layers.19/mlp/Mul_output_0 + - /language_model/layers.20/mlp/Mul_output_0 + - /language_model/layers.21/mlp/Mul_output_0 + - /language_model/layers.22/mlp/Mul_output_0 + - /language_model/layers.23/mlp/Mul_output_0 + - /language_model/layers.24/mlp/Mul_output_0 + - /language_model/layers.25/mlp/Mul_output_0 + - /language_model/layers.26/mlp/Mul_output_0 + - /language_model/layers.27/mlp/Mul_output_0 + - /language_model/layers.28/mlp/Mul_output_0 + - /language_model/layers.29/mlp/Mul_output_0 + - /language_model/layers.30/mlp/Mul_output_0 + - /language_model/layers.31/mlp/Mul_output_0 + - /language_model/layers.32/mlp/Mul_output_0 + - /language_model/layers.33/mlp/Mul_output_0 + - /language_model/layers.34/mlp/Mul_output_0 + - /language_model/layers.35/mlp/Mul_output_0 + - /language_model/layers.36/mlp/Mul_output_0 + - /language_model/layers.37/mlp/Mul_output_0 + - /language_model/layers.38/mlp/Mul_output_0 + - /language_model/layers.39/mlp/Mul_output_0 + - /language_model/layers.40/mlp/Mul_output_0 + - /language_model/layers.41/mlp/Mul_output_0 + - /language_model/layers.42/mlp/Mul_output_0 + - /language_model/layers.43/mlp/Mul_output_0 + - /language_model/layers.44/mlp/Mul_output_0 + - /language_model/layers.45/mlp/Mul_output_0 + - /language_model/layers.46/mlp/Mul_output_0 + - /language_model/layers.47/mlp/Mul_output_0 + - /language_model/layers.48/mlp/Mul_output_0 + - /language_model/layers.49/mlp/Mul_output_0 + - /language_model/layers.50/mlp/Mul_output_0 + - /language_model/layers.51/mlp/Mul_output_0 + - /language_model/layers.52/mlp/Mul_output_0 + - /language_model/layers.53/mlp/Mul_output_0 + - /language_model/layers.54/mlp/Mul_output_0 + - /language_model/layers.55/mlp/Mul_output_0 + - /language_model/layers.56/mlp/Mul_output_0 + - /language_model/layers.57/mlp/Mul_output_0 + - /language_model/layers.58/mlp/Mul_output_0 + - /language_model/layers.59/mlp/Mul_output_0 + - /language_model/layers.60/mlp/Mul_output_0 + - /language_model/layers.61/mlp/Mul_output_0 + - /language_model/layers.0/Add_1_output_0 + - /language_model/layers.0/Add_2_output_0 + - /language_model/layers.0/Add_3_output_0 + - /language_model/layers.0/Add_output_0 + - /language_model/layers.1/Add_1_output_0 + - /language_model/layers.1/Add_2_output_0 + - /language_model/layers.1/Add_3_output_0 + - /language_model/layers.1/Add_output_0 + - /language_model/layers.2/Add_1_output_0 + - /language_model/layers.2/Add_2_output_0 + - /language_model/layers.2/Add_3_output_0 + - /language_model/layers.2/Add_output_0 + - /language_model/layers.3/Add_1_output_0 + - /language_model/layers.3/Add_2_output_0 + - /language_model/layers.3/Add_3_output_0 + - /language_model/layers.3/Add_output_0 + - /language_model/layers.4/Add_1_output_0 + - /language_model/layers.4/Add_2_output_0 + - /language_model/layers.4/Add_3_output_0 + - /language_model/layers.4/Add_output_0 + - /language_model/layers.5/Add_1_output_0 + - /language_model/layers.5/Add_2_output_0 + - /language_model/layers.5/Add_3_output_0 + - /language_model/layers.5/Add_output_0 + - /language_model/layers.6/Add_1_output_0 + - /language_model/layers.6/Add_2_output_0 + - /language_model/layers.6/Add_3_output_0 + - /language_model/layers.6/Add_output_0 + - /language_model/layers.7/Add_1_output_0 + - /language_model/layers.7/Add_2_output_0 + - /language_model/layers.7/Add_3_output_0 + - /language_model/layers.7/Add_output_0 + - /language_model/layers.8/Add_1_output_0 + - /language_model/layers.8/Add_2_output_0 + - /language_model/layers.8/Add_3_output_0 + - /language_model/layers.8/Add_output_0 + - /language_model/layers.9/Add_1_output_0 + - /language_model/layers.9/Add_2_output_0 + - /language_model/layers.9/Add_3_output_0 + - /language_model/layers.9/Add_output_0 + - /language_model/layers.10/Add_1_output_0 + - /language_model/layers.10/Add_2_output_0 + - /language_model/layers.10/Add_3_output_0 + - /language_model/layers.10/Add_output_0 + - /language_model/layers.11/Add_1_output_0 + - /language_model/layers.11/Add_2_output_0 + - /language_model/layers.11/Add_3_output_0 + - /language_model/layers.11/Add_output_0 + - /language_model/layers.12/Add_1_output_0 + - /language_model/layers.12/Add_2_output_0 + - /language_model/layers.12/Add_3_output_0 + - /language_model/layers.12/Add_output_0 + - /language_model/layers.13/Add_1_output_0 + - /language_model/layers.13/Add_2_output_0 + - /language_model/layers.13/Add_3_output_0 + - /language_model/layers.13/Add_output_0 + - /language_model/layers.14/Add_1_output_0 + - /language_model/layers.14/Add_2_output_0 + - /language_model/layers.14/Add_3_output_0 + - /language_model/layers.14/Add_output_0 + - /language_model/layers.15/Add_1_output_0 + - /language_model/layers.15/Add_2_output_0 + - /language_model/layers.15/Add_3_output_0 + - /language_model/layers.15/Add_output_0 + - /language_model/layers.16/Add_1_output_0 + - /language_model/layers.16/Add_2_output_0 + - /language_model/layers.16/Add_3_output_0 + - /language_model/layers.16/Add_output_0 + - /language_model/layers.17/Add_1_output_0 + - /language_model/layers.17/Add_2_output_0 + - /language_model/layers.17/Add_3_output_0 + - /language_model/layers.17/Add_output_0 + - /language_model/layers.18/Add_1_output_0 + - /language_model/layers.18/Add_2_output_0 + - /language_model/layers.18/Add_3_output_0 + - /language_model/layers.18/Add_output_0 + - /language_model/layers.19/Add_1_output_0 + - /language_model/layers.19/Add_2_output_0 + - /language_model/layers.19/Add_3_output_0 + - /language_model/layers.19/Add_output_0 + - /language_model/layers.20/Add_1_output_0 + - /language_model/layers.20/Add_2_output_0 + - /language_model/layers.20/Add_3_output_0 + - /language_model/layers.20/Add_output_0 + - /language_model/layers.21/Add_1_output_0 + - /language_model/layers.21/Add_2_output_0 + - /language_model/layers.21/Add_3_output_0 + - /language_model/layers.21/Add_output_0 + - /language_model/layers.22/Add_1_output_0 + - /language_model/layers.22/Add_2_output_0 + - /language_model/layers.22/Add_3_output_0 + - /language_model/layers.22/Add_output_0 + - /language_model/layers.23/Add_1_output_0 + - /language_model/layers.23/Add_2_output_0 + - /language_model/layers.23/Add_output_0 + - /language_model/layers.24/Add_1_output_0 + - /language_model/layers.24/Add_2_output_0 + - /language_model/layers.24/Add_3_output_0 + - /language_model/layers.24/Add_output_0 + - /language_model/layers.25/Add_1_output_0 + - /language_model/layers.25/Add_2_output_0 + - /language_model/layers.25/Add_3_output_0 + - /language_model/layers.25/Add_output_0 + - /language_model/layers.26/Add_1_output_0 + - /language_model/layers.26/Add_2_output_0 + - /language_model/layers.26/Add_3_output_0 + - /language_model/layers.26/Add_output_0 + - /language_model/layers.27/Add_1_output_0 + - /language_model/layers.27/Add_2_output_0 + - /language_model/layers.27/Add_3_output_0 + - /language_model/layers.27/Add_output_0 + - /language_model/layers.28/Add_1_output_0 + - /language_model/layers.28/Add_2_output_0 + - /language_model/layers.28/Add_3_output_0 + - /language_model/layers.28/Add_output_0 + - /language_model/layers.29/Add_1_output_0 + - /language_model/layers.29/Add_2_output_0 + - /language_model/layers.29/Add_3_output_0 + - /language_model/layers.29/Add_output_0 + - /language_model/layers.30/Add_1_output_0 + - /language_model/layers.30/Add_2_output_0 + - /language_model/layers.30/Add_3_output_0 + - /language_model/layers.30/Add_output_0 + - /language_model/layers.31/Add_1_output_0 + - /language_model/layers.31/Add_2_output_0 + - /language_model/layers.31/Add_3_output_0 + - /language_model/layers.31/Add_output_0 + - /language_model/layers.32/Add_1_output_0 + - /language_model/layers.32/Add_2_output_0 + - /language_model/layers.32/Add_3_output_0 + - /language_model/layers.32/Add_output_0 + - /language_model/layers.33/Add_1_output_0 + - /language_model/layers.33/Add_2_output_0 + - /language_model/layers.33/Add_3_output_0 + - /language_model/layers.33/Add_output_0 + - /language_model/layers.34/Add_1_output_0 + - /language_model/layers.34/Add_2_output_0 + - /language_model/layers.34/Add_3_output_0 + - /language_model/layers.34/Add_output_0 + - /language_model/layers.35/Add_1_output_0 + - /language_model/layers.35/Add_2_output_0 + - /language_model/layers.35/Add_3_output_0 + - /language_model/layers.35/Add_output_0 + - /language_model/layers.36/Add_1_output_0 + - /language_model/layers.36/Add_2_output_0 + - /language_model/layers.36/Add_3_output_0 + - /language_model/layers.36/Add_output_0 + - /language_model/layers.37/Add_1_output_0 + - /language_model/layers.37/Add_2_output_0 + - /language_model/layers.37/Add_3_output_0 + - /language_model/layers.37/Add_output_0 + - /language_model/layers.38/Add_1_output_0 + - /language_model/layers.38/Add_2_output_0 + - /language_model/layers.38/Add_3_output_0 + - /language_model/layers.38/Add_output_0 + - /language_model/layers.39/Add_1_output_0 + - /language_model/layers.39/Add_2_output_0 + - /language_model/layers.39/Add_3_output_0 + - /language_model/layers.39/Add_output_0 + - /language_model/layers.40/Add_1_output_0 + - /language_model/layers.40/Add_2_output_0 + - /language_model/layers.40/Add_3_output_0 + - /language_model/layers.40/Add_output_0 + - /language_model/layers.41/Add_1_output_0 + - /language_model/layers.41/Add_2_output_0 + - /language_model/layers.41/Add_3_output_0 + - /language_model/layers.41/Add_output_0 + - /language_model/layers.42/Add_1_output_0 + - /language_model/layers.42/Add_2_output_0 + - /language_model/layers.42/Add_3_output_0 + - /language_model/layers.42/Add_output_0 + - /language_model/layers.43/Add_1_output_0 + - /language_model/layers.43/Add_2_output_0 + - /language_model/layers.43/Add_3_output_0 + - /language_model/layers.43/Add_output_0 + - /language_model/layers.44/Add_1_output_0 + - /language_model/layers.44/Add_2_output_0 + - /language_model/layers.44/Add_3_output_0 + - /language_model/layers.44/Add_output_0 + - /language_model/layers.45/Add_1_output_0 + - /language_model/layers.45/Add_2_output_0 + - /language_model/layers.45/Add_3_output_0 + - /language_model/layers.45/Add_output_0 + - /language_model/layers.46/Add_1_output_0 + - /language_model/layers.46/Add_2_output_0 + - /language_model/layers.46/Add_3_output_0 + - /language_model/layers.46/Add_output_0 + - /language_model/layers.47/Add_1_output_0 + - /language_model/layers.47/Add_2_output_0 + - /language_model/layers.47/Add_3_output_0 + - /language_model/layers.47/Add_output_0 + - /language_model/layers.48/Add_1_output_0 + - /language_model/layers.48/Add_2_output_0 + - /language_model/layers.48/Add_3_output_0 + - /language_model/layers.48/Add_output_0 + - /language_model/layers.49/Add_1_output_0 + - /language_model/layers.49/Add_2_output_0 + - /language_model/layers.49/Add_3_output_0 + - /language_model/layers.49/Add_output_0 + - /language_model/layers.50/Add_1_output_0 + - /language_model/layers.50/Add_2_output_0 + - /language_model/layers.50/Add_3_output_0 + - /language_model/layers.50/Add_output_0 + - /language_model/layers.51/Add_1_output_0 + - /language_model/layers.51/Add_2_output_0 + - /language_model/layers.51/Add_3_output_0 + - /language_model/layers.51/Add_output_0 + - /language_model/layers.52/Add_1_output_0 + - /language_model/layers.52/Add_2_output_0 + - /language_model/layers.52/Add_3_output_0 + - /language_model/layers.52/Add_output_0 + - /language_model/layers.53/Add_1_output_0 + - /language_model/layers.53/Add_2_output_0 + - /language_model/layers.53/Add_3_output_0 + - /language_model/layers.53/Add_output_0 + - /language_model/layers.54/Add_1_output_0 + - /language_model/layers.54/Add_2_output_0 + - /language_model/layers.54/Add_3_output_0 + - /language_model/layers.54/Add_output_0 + - /language_model/layers.55/Add_1_output_0 + - /language_model/layers.55/Add_2_output_0 + - /language_model/layers.55/Add_3_output_0 + - /language_model/layers.55/Add_output_0 + - /language_model/layers.56/Add_1_output_0 + - /language_model/layers.56/Add_2_output_0 + - /language_model/layers.56/Add_3_output_0 + - /language_model/layers.56/Add_output_0 + - /language_model/layers.57/Add_1_output_0 + - /language_model/layers.57/Add_2_output_0 + - /language_model/layers.57/Add_3_output_0 + - /language_model/layers.57/Add_output_0 + - /language_model/layers.58/Add_1_output_0 + - /language_model/layers.58/Add_2_output_0 + - /language_model/layers.58/Add_3_output_0 + - /language_model/layers.58/Add_output_0 + - /language_model/layers.59/Add_1_output_0 + - /language_model/layers.59/Add_2_output_0 + - /language_model/layers.59/Add_3_output_0 + - /language_model/layers.59/Add_output_0 + - /language_model/layers.60/Add_1_output_0 + - /language_model/layers.60/Add_2_output_0 + - /language_model/layers.60/Add_3_output_0 + - /language_model/layers.60/Add_output_0 + - /language_model/layers.61/Add_1_output_0 + - /language_model/layers.61/Add_2_output_0 + - /language_model/layers.61/Add_3_output_0 + - /language_model/layers.61/Add_output_0 + - /language_model/norm/Add_output_0 + - /language_model/layers.0/self_attn/Mul_output_0 + - /language_model/layers.2/self_attn/Mul_output_0 + - /language_model/layers.3/self_attn/Mul_output_0 + - /language_model/layers.4/self_attn/Mul_output_0 + - /language_model/layers.5/self_attn/Mul_output_0 + - /language_model/layers.6/self_attn/Mul_output_0 + - /language_model/layers.7/self_attn/Mul_output_0 + - /language_model/layers.8/self_attn/Mul_output_0 + - /language_model/layers.9/self_attn/Mul_output_0 + - /language_model/layers.10/self_attn/Mul_output_0 + - /language_model/layers.11/self_attn/Mul_output_0 + - /language_model/layers.12/self_attn/Mul_output_0 + - /language_model/layers.13/self_attn/Mul_output_0 + - /language_model/layers.14/self_attn/Mul_output_0 + - /language_model/layers.15/self_attn/Mul_output_0 + - /language_model/layers.16/self_attn/Mul_output_0 + - /language_model/layers.17/self_attn/Mul_output_0 + - /language_model/layers.18/self_attn/Mul_output_0 + - /language_model/layers.19/self_attn/Mul_output_0 + - /language_model/layers.20/self_attn/Mul_output_0 + - /language_model/layers.21/self_attn/Mul_output_0 + - /language_model/layers.22/self_attn/Mul_output_0 + - /language_model/layers.23/self_attn/Mul_output_0 + - /language_model/layers.24/self_attn/Mul_output_0 + - /language_model/layers.25/self_attn/Mul_output_0 + - /language_model/layers.26/self_attn/Mul_output_0 + - /language_model/layers.27/self_attn/Mul_output_0 + - /language_model/layers.28/self_attn/Mul_output_0 + - /language_model/layers.29/self_attn/Mul_output_0 + - /language_model/layers.30/self_attn/Mul_output_0 + - /language_model/layers.31/self_attn/Mul_output_0 + - /language_model/layers.32/self_attn/Mul_output_0 + - /language_model/layers.33/self_attn/Mul_output_0 + - /language_model/layers.34/self_attn/Mul_output_0 + - /language_model/layers.35/self_attn/Mul_output_0 + - /language_model/layers.36/self_attn/Mul_output_0 + - /language_model/layers.37/self_attn/Mul_output_0 + - /language_model/layers.38/self_attn/Mul_output_0 + - /language_model/layers.39/self_attn/Mul_output_0 + - /language_model/layers.40/self_attn/Mul_output_0 + - /language_model/layers.41/self_attn/Mul_output_0 + - /language_model/layers.42/self_attn/Mul_output_0 + - /language_model/layers.43/self_attn/Mul_output_0 + - /language_model/layers.44/self_attn/Mul_output_0 + - /language_model/layers.45/self_attn/Mul_output_0 + - /language_model/layers.46/self_attn/Mul_output_0 + - /language_model/layers.47/self_attn/Mul_output_0 + - /language_model/layers.48/self_attn/Mul_output_0 + - /language_model/layers.49/self_attn/Mul_output_0 + - /language_model/layers.50/self_attn/Mul_output_0 + - /language_model/layers.51/self_attn/Mul_output_0 + - /language_model/layers.52/self_attn/Mul_output_0 + - /language_model/layers.53/self_attn/Mul_output_0 + - /language_model/layers.54/self_attn/Mul_output_0 + - /language_model/layers.55/self_attn/Mul_output_0 + - /language_model/layers.56/self_attn/Mul_output_0 + - /language_model/layers.57/self_attn/Mul_output_0 + - /language_model/layers.58/self_attn/Mul_output_0 + - /language_model/layers.59/self_attn/Mul_output_0 + - /language_model/layers.60/self_attn/Mul_output_0 + - /language_model/layers.61/self_attn/Mul_output_0 + - /language_model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.34/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.34/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.35/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.35/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.36/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.36/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.37/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.37/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.38/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.38/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.39/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.39/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.40/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.40/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.41/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.41/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.42/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.42/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.43/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.43/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.44/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.44/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.45/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.45/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.46/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.46/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.47/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.47/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.48/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.48/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.49/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.49/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.50/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.50/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.51/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.51/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.52/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.52/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.53/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.53/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.54/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.54/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.55/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.55/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.56/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.56/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.57/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.57/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.58/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.58/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.59/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.59/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.60/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.60/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.61/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.61/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/norm/CustomRMSNorm_output_0 diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index b091eea4a..c111f2f73 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2848,18 +2848,22 @@ def export( self.model.config, fbs if self.continuous_batching else bs, seq_len ) enable_chunking = kwargs.get("enable_chunking", False) - - # TODO: move this to a DA Serving utility class if self.model.config.model_type in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH: if prefill_only: - if self.continuous_batching and not enable_chunking: - raise NotImplementedError("Can't enable prefix-caching without chunking") + if not enable_chunking and self.continuous_batching: + raise NotImplementedError( + "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" + ) self.prefill(enable=True, enable_chunking=enable_chunking) self.hash_params.pop("retain_full_kv", None) seq_len = self.get_seq_len_and_handle_specialized_prefill_model( prefill_seq_len=prefill_seq_len, enable_chunking=enable_chunking ) - kv_cache_shape[2] = seq_len + self.model.config.sliding_window if enable_chunking else seq_len + kv_cache_shape[2] = ( + seq_len + (self.model.config.sliding_window if self.model.config.sliding_window is not None else 0) + if enable_chunking + else seq_len + ) else: self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False)) self.hash_params.pop("prefill_only", None) @@ -2868,7 +2872,9 @@ def export( self.hash_params.pop("ENABLE_OPT_SWA", None) self.hash_params.pop("chunking", None) if kwargs.get("retain_full_kv", False): - kv_cache_shape[2] = seq_len + self.model.config.sliding_window + kv_cache_shape[2] = seq_len + ( + self.model.config.sliding_window if self.model.config.sliding_window is not None else 0 + ) self.hash_params["retain_full_kv"] = True example_inputs = { @@ -3427,6 +3433,8 @@ def check_and_get_num_speculative_tokens(self, num_speculative_tokens: Optional[ If `num_speculative_tokens` is not an integer greater than 1. If `prefill_seq_len` is less than `num_speculative_tokens + 1`. """ + if not self.is_tlm: + return None if hasattr(self.model.config, "speculative_config"): num_speculative_tokens_ = self.model.config.speculative_config["num_speculative_tokens"] if num_speculative_tokens is not None: diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index f946b1de2..f1daf3014 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -425,6 +425,7 @@ QEffQwen3Model, ) from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import ( + QEffPrefillChunkedQwen3MoeSparseMoeBlock, QEffQwen3MoeAttention, QEffQwen3MoeDecoderLayer, QEffQwen3MoeForCausalLM, @@ -669,19 +670,25 @@ class PrefillOnlyTransform(ModuleMappingTransform): class PrefillOnlyChunkedTransform(ModuleMappingTransform): _module_mapping = { + # GPT_OSS QEffGptOssModel: QEffPrefillOnlyGptOssModel, QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP, + # Qwen3Moe + QEffQwen3MoeSparseMoeBlock: QEffPrefillChunkedQwen3MoeSparseMoeBlock, } class RevertPrefillKeepAttentionTransform(ModuleMappingTransform): _module_mapping = { + # GPT_OSS QEffGptOssModel: QEffPrefillOnlyGptOssModel, QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, QEffPrefillOnlyGptOssMLP: QEffGptOssMLP, QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP, + # Qwen3Moe + QEffPrefillChunkedQwen3MoeSparseMoeBlock: QEffQwen3MoeSparseMoeBlock, } diff --git a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py index d44668c56..6bdd5e243 100644 --- a/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/QEfficient/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -104,7 +104,6 @@ def eager_attention_forward( key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) - attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling if attention_mask is not None: attn_weights = torch.where( @@ -118,53 +117,50 @@ def eager_attention_forward( return attn_output, attn_weights -class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): - def __qeff_init__(self): - self.gate_proj_w = [] - self.up_proj_w = [] - self.down_proj_w = [] - with torch.no_grad(): - for e in range(self.num_experts): - self.gate_proj_w.append(self.experts[e].gate_proj.weight.T) - self.up_proj_w.append(self.experts[e].up_proj.weight.T) - self.down_proj_w.append(self.experts[e].down_proj.weight.T) - self.gate_proj_w = torch.stack(self.gate_proj_w) - self.up_proj_w = torch.stack(self.up_proj_w) - self.down_proj_w = torch.stack(self.down_proj_w) - - def alt_forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: +class QEffPrefillChunkedQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, S, H = hidden_states.shape T = B * S x = hidden_states.view(T, H) - router_logits = self.gate(x) # [T, E] prob = F.softmax(router_logits, -1, dtype=torch.float) top_w, top_i = torch.topk(prob, self.top_k, -1) if self.norm_topk_prob: # only diff with mixtral sparse moe block! top_w /= top_w.sum(-1, keepdim=True) - top_w = top_w.to(x.dtype) + top_w = top_w.to(hidden_states.dtype) masked_logits = torch.zeros_like(router_logits) masked_logits.scatter_(1, top_i, top_w) - # Routing weights for each expert [T, E] routing_weights = masked_logits - # ────────────────── allocate the output tensor ───── expert_out = x.new_zeros((T, H)) # accumulation buffer - # ───────────────────────── Expert computation loop ───────────────────────────── for e in range(self.num_experts): routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] - W_g, W_u = self.experts[e].gate_proj, self.experts[e].up_proj # [H, I], [H, I] - W_d = self.experts[e].down_proj # [I, H] - gate = W_g(x) # [T, I] - up = W_u(x) # [T, I] - down = W_d(up * self.experts[e].act_fn(gate)) # [T, H] - - masked_down = torch.where(routing_weight > 0, down * routing_weight, torch.zeros_like(expert_out)) + W_g, W_u = self.experts[e].gate_proj.weight.T, self.experts[e].up_proj.weight.T # [H, I], [H, I] + W_d = self.experts[e].down_proj.weight.T # [I, H] + gate = x @ W_g # [T, I] + up = x @ W_u # [T, I] + down = (up * self.experts[e].act_fn(gate)) @ W_d # [T, H] + masked_down = down * routing_weight expert_out += masked_down return expert_out.view(B, S, H), router_logits + +class QEffQwen3MoeSparseMoeBlock(Qwen3MoeSparseMoeBlock): + def __qeff_init__(self): + self.gate_proj_w = [] + self.up_proj_w = [] + self.down_proj_w = [] + with torch.no_grad(): + for e in range(self.num_experts): + self.gate_proj_w.append(self.experts[e].gate_proj.weight.T) + self.up_proj_w.append(self.experts[e].up_proj.weight.T) + self.down_proj_w.append(self.experts[e].down_proj.weight.T) + self.gate_proj_w = torch.stack(self.gate_proj_w) + self.up_proj_w = torch.stack(self.up_proj_w) + self.down_proj_w = torch.stack(self.down_proj_w) + def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: B, S, H = hidden_states.shape T = B * S diff --git a/docs/source/quick_start.md b/docs/source/quick_start.md index f15d8de2f..91f351ff5 100644 --- a/docs/source/quick_start.md +++ b/docs/source/quick_start.md @@ -111,6 +111,7 @@ This is the single e2e CLI API, which takes `model_card` name as input along wit * HuggingFace model files Download → Optimize for Cloud AI 100 → Export to `ONNX` → Compile on Cloud AI 100 → [Execute](#execute_api) * It skips the export/compile stage based if `ONNX` or `qpc` files are found. If you use infer second time with different compilation arguments, it will automatically skip `ONNX` model creation and directly jump to compile stage. +* ONNX subfunctions can be enabled explicitly using `--use-onnx-subfunctions`. ```bash @@ -118,6 +119,11 @@ This is the single e2e CLI API, which takes `model_card` name as input along wit python -m QEfficient.cloud.infer --help python -m QEfficient.cloud.infer --model_name gpt2 --batch_size 1 --prompt_len 32 --ctx_len 128 --mxfp6 --num_cores 16 --device_group [0] --prompt "My name is" --mos 1 --aic_enable_depth_first ``` + +```bash +# Optional: explicitly control ONNX subfunction usage +python -m QEfficient.cloud.infer --model_name Qwen/Qwen3-30B-A3B-Instruct-2507 --batch_size 1 --prompt_len 32 --ctx_len 128 --num_cores 16 --device_group [0] --prompt "My name is" --use-onnx-subfunctions +``` If executing for batch size>1, You can pass input prompts in single string but separate with pipe (|) symbol". Example below diff --git a/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py b/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py new file mode 100644 index 000000000..655de4ef5 --- /dev/null +++ b/examples/disagg_serving/qwen3moe_disagg_mode_with_chunking.py @@ -0,0 +1,133 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import time + +import numpy as np +import torch +from transformers import AutoConfig, AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession + +model_id = "Qwen/Qwen3-30B-A3B-Instruct-2507" # weights are not required to convert to fp32 +prompt = """ +Explain quantum computing in simple terms. +""" +config = AutoConfig.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) +PREFILL_SEQ_LEN = 128 +CTX_LEN = 128 * 3 + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) +decode_qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step + retain_full_kv=True, +) + +# Following command errors out by default, the user is supposed to run the printed command and provide the generated qpc path as prefill_qpc_path commenting out lines 55-68 + +# prefill_qpc_path = "" + +prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=2, + split_retained_state_io=True, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + enable_chunking=True, + # use_onnx_subfunctions=True, +) + + +inputs = tokenizer(prompt, return_tensors="np", padding=True) +position_ids = inputs["attention_mask"].sum(1, keepdims=True) +generation_len = CTX_LEN - position_ids.max() +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +inputs.pop("past_key_values", None) +inputs = {k: v.detach().numpy() for k, v in inputs.items()} + + +prefill_session = QAICInferenceSession(prefill_qpc_path) +decode_session = QAICInferenceSession(decode_qpc_path) + +all_outputs = [] +for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + ins = time.time() + qpc_out = prefill_session.run(chunk_inputs) + print(f"time for this run={time.time() - ins}") + for i in range(config.num_hidden_layers): + inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +all_outputs.append(np.argmax(qpc_out["logits"])) + +decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, +} +for i in range(config.num_hidden_layers): + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +st = time.time() +decode_out = decode_session.run(decode_inputs) +print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") +all_outputs.append(np.argmax(decode_out["logits"])) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, +} + +for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + +st = time.time() +for i in range(generation_len - 2): + decode_out = decode_session.run(loop_decode_inputs) + all_outputs.append(np.argmax(decode_out["logits"])) + pos_id += 1 + for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + + loop_decode_inputs.update( + { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + ) +ft = time.time() + +print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") +print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}") diff --git a/examples/image_text_to_text/models/gemma_vision/README.md b/examples/image_text_to_text/models/gemma_vision/README.md new file mode 100644 index 000000000..448f0a9eb --- /dev/null +++ b/examples/image_text_to_text/models/gemma_vision/README.md @@ -0,0 +1,40 @@ +# Gemma3 NPI Files + +a) For Gemma3-4B model user is adviced to use the NPI file namely fp32_nodes_gemma3_4b.yaml + example compile command - + npi_file_path = "configs/fp32_nodes_gemma3_4b.yaml" + npi_file_full_path = os.path.join(os.getcwd(), npi_file_path) + + qeff_model.compile( + prefill_seq_len=128, + ctx_len=3072, + img_size=896, + num_cores=16, + num_devices=1, + mxfp6_matmul=False, + mxint8_kv_cache=False, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + node_precision_info=npi_file_full_path + ) + +b) For Gemma3-27B model user is adviced to use the NPI file namely gemma_updated_npi.yaml + + example compile command - + npi_file_path = "configs/gemma_updated_npi.yaml" + npi_file_full_path = os.path.join(os.getcwd(), npi_file_path) + + qeff_model.compile( + prefill_seq_len=128, + ctx_len=3072, + img_size=896, + num_cores=16, + num_devices=1, + mxfp6_matmul=False, + mxint8_kv_cache=False, + aic_enable_depth_first=True, + skip_vision=True, + mos=1, + node_precision_info=npi_file_full_path + ) \ No newline at end of file diff --git a/examples/image_text_to_text/models/gemma_vision/configs/gemma_updated_npi.yaml b/examples/image_text_to_text/models/gemma_vision/configs/gemma_updated_npi.yaml new file mode 100644 index 000000000..faf4f9d72 --- /dev/null +++ b/examples/image_text_to_text/models/gemma_vision/configs/gemma_updated_npi.yaml @@ -0,0 +1,1564 @@ +FP16NodeInstanceNames: + - /lm_head/MatMul_output_0 + - onnx::MatMul_25530 + +FP32NodeInstanceNames: + + + #Mul + - /language_model/layers.0/mlp/act_fn/Mul_output_0 + - /language_model/layers.0/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.0/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.0/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.0/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.0/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.1/mlp/act_fn/Mul_output_0 + - /language_model/layers.1/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.1/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.1/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.1/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.1/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.2/mlp/act_fn/Mul_output_0 + - /language_model/layers.2/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.2/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.2/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.2/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.2/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.3/mlp/act_fn/Mul_output_0 + - /language_model/layers.3/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.3/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.3/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.3/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.3/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.4/mlp/act_fn/Mul_output_0 + - /language_model/layers.4/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.4/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.4/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.4/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.4/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.5/mlp/act_fn/Mul_output_0 + - /language_model/layers.5/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.5/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.5/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.5/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.5/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.6/mlp/act_fn/Mul_output_0 + - /language_model/layers.6/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.6/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.6/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.6/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.6/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.7/mlp/act_fn/Mul_output_0 + - /language_model/layers.7/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.7/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.7/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.7/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.7/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.8/mlp/act_fn/Mul_output_0 + - /language_model/layers.8/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.8/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.8/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.8/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.8/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.9/mlp/act_fn/Mul_output_0 + - /language_model/layers.9/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.9/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.9/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.9/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.9/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.10/mlp/act_fn/Mul_output_0 + - /language_model/layers.10/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.10/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.10/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.10/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.10/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.11/mlp/act_fn/Mul_output_0 + - /language_model/layers.11/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.11/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.11/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.11/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.11/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.12/mlp/act_fn/Mul_output_0 + - /language_model/layers.12/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.12/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.12/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.12/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.12/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.13/mlp/act_fn/Mul_output_0 + - /language_model/layers.13/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.13/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.13/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.13/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.13/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.14/mlp/act_fn/Mul_output_0 + - /language_model/layers.14/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.14/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.14/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.14/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.14/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.15/mlp/act_fn/Mul_output_0 + - /language_model/layers.15/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.15/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.15/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.15/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.15/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.16/mlp/act_fn/Mul_output_0 + - /language_model/layers.16/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.16/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.16/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.16/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.16/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.17/mlp/act_fn/Mul_output_0 + - /language_model/layers.17/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.17/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.17/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.17/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.17/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.18/mlp/act_fn/Mul_output_0 + - /language_model/layers.18/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.18/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.18/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.18/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.18/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.19/mlp/act_fn/Mul_output_0 + - /language_model/layers.19/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.19/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.19/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.19/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.19/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.20/mlp/act_fn/Mul_output_0 + - /language_model/layers.20/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.20/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.20/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.20/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.20/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.21/mlp/act_fn/Mul_output_0 + - /language_model/layers.21/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.21/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.21/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.21/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.21/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.22/mlp/act_fn/Mul_output_0 + - /language_model/layers.22/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.22/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.22/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.22/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.22/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.23/mlp/act_fn/Mul_output_0 + - /language_model/layers.23/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.23/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.23/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.23/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.23/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.24/mlp/act_fn/Mul_output_0 + - /language_model/layers.24/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.24/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.24/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.24/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.24/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.25/mlp/act_fn/Mul_output_0 + - /language_model/layers.25/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.25/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.25/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.25/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.25/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.26/mlp/act_fn/Mul_output_0 + - /language_model/layers.26/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.26/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.26/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.26/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.26/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.27/mlp/act_fn/Mul_output_0 + - /language_model/layers.27/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.27/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.27/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.27/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.27/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.28/mlp/act_fn/Mul_output_0 + - /language_model/layers.28/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.28/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.28/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.28/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.28/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.29/mlp/act_fn/Mul_output_0 + - /language_model/layers.29/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.29/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.29/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.29/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.29/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.30/mlp/act_fn/Mul_output_0 + - /language_model/layers.30/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.30/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.30/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.30/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.30/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.31/mlp/act_fn/Mul_output_0 + - /language_model/layers.31/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.31/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.31/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.31/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.31/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.32/mlp/act_fn/Mul_output_0 + - /language_model/layers.32/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.32/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.32/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.32/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.32/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.33/mlp/act_fn/Mul_output_0 + - /language_model/layers.33/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.33/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.33/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.33/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.33/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.34/mlp/act_fn/Mul_output_0 + - /language_model/layers.34/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.34/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.34/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.34/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.34/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.35/mlp/act_fn/Mul_output_0 + - /language_model/layers.35/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.35/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.35/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.35/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.35/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.36/mlp/act_fn/Mul_output_0 + - /language_model/layers.36/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.36/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.36/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.36/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.36/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.37/mlp/act_fn/Mul_output_0 + - /language_model/layers.37/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.37/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.37/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.37/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.37/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.38/mlp/act_fn/Mul_output_0 + - /language_model/layers.38/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.38/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.38/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.38/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.38/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.39/mlp/act_fn/Mul_output_0 + - /language_model/layers.39/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.39/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.39/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.39/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.39/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.40/mlp/act_fn/Mul_output_0 + - /language_model/layers.40/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.40/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.40/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.40/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.40/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.41/mlp/act_fn/Mul_output_0 + - /language_model/layers.41/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.41/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.41/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.41/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.41/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.42/mlp/act_fn/Mul_output_0 + - /language_model/layers.42/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.42/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.42/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.42/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.42/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.43/mlp/act_fn/Mul_output_0 + - /language_model/layers.43/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.43/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.43/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.43/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.43/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.44/mlp/act_fn/Mul_output_0 + - /language_model/layers.44/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.44/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.44/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.44/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.44/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.45/mlp/act_fn/Mul_output_0 + - /language_model/layers.45/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.45/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.45/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.45/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.45/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.46/mlp/act_fn/Mul_output_0 + - /language_model/layers.46/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.46/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.46/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.46/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.46/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.47/mlp/act_fn/Mul_output_0 + - /language_model/layers.47/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.47/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.47/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.47/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.47/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.48/mlp/act_fn/Mul_output_0 + - /language_model/layers.48/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.48/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.48/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.48/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.48/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.49/mlp/act_fn/Mul_output_0 + - /language_model/layers.49/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.49/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.49/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.49/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.49/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.50/mlp/act_fn/Mul_output_0 + - /language_model/layers.50/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.50/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.50/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.50/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.50/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.51/mlp/act_fn/Mul_output_0 + - /language_model/layers.51/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.51/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.51/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.51/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.51/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.52/mlp/act_fn/Mul_output_0 + - /language_model/layers.52/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.52/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.52/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.52/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.52/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.53/mlp/act_fn/Mul_output_0 + - /language_model/layers.53/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.53/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.53/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.53/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.53/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.54/mlp/act_fn/Mul_output_0 + - /language_model/layers.54/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.54/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.54/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.54/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.54/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.55/mlp/act_fn/Mul_output_0 + - /language_model/layers.55/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.55/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.55/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.55/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.55/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.56/mlp/act_fn/Mul_output_0 + - /language_model/layers.56/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.56/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.56/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.56/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.56/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.57/mlp/act_fn/Mul_output_0 + - /language_model/layers.57/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.57/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.57/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.57/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.57/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.58/mlp/act_fn/Mul_output_0 + - /language_model/layers.58/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.58/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.58/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.58/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.58/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.59/mlp/act_fn/Mul_output_0 + - /language_model/layers.59/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.59/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.59/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.59/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.59/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.60/mlp/act_fn/Mul_output_0 + - /language_model/layers.60/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.60/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.60/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.60/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.60/mlp/act_fn/Mul_5_output_0 + - /language_model/layers.61/mlp/act_fn/Mul_output_0 + - /language_model/layers.61/mlp/act_fn/Mul_1_output_0 + - /language_model/layers.61/mlp/act_fn/Mul_2_output_0 + - /language_model/layers.61/mlp/act_fn/Mul_3_output_0 + - /language_model/layers.61/mlp/act_fn/Mul_4_output_0 + - /language_model/layers.61/mlp/act_fn/Mul_5_output_0 + + #Constant + - /language_model/layers.0/mlp/act_fn/Constant_output_0 + - /language_model/layers.0/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.0/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.0/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.1/mlp/act_fn/Constant_output_0 + - /language_model/layers.1/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.1/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.1/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.2/mlp/act_fn/Constant_output_0 + - /language_model/layers.2/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.2/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.2/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.3/mlp/act_fn/Constant_output_0 + - /language_model/layers.3/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.3/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.3/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.4/mlp/act_fn/Constant_output_0 + - /language_model/layers.4/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.4/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.4/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.5/mlp/act_fn/Constant_output_0 + - /language_model/layers.5/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.5/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.5/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.6/mlp/act_fn/Constant_output_0 + - /language_model/layers.6/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.6/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.6/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.7/mlp/act_fn/Constant_output_0 + - /language_model/layers.7/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.7/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.7/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.8/mlp/act_fn/Constant_output_0 + - /language_model/layers.8/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.8/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.8/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.9/mlp/act_fn/Constant_output_0 + - /language_model/layers.9/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.9/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.9/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.10/mlp/act_fn/Constant_output_0 + - /language_model/layers.10/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.10/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.10/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.11/mlp/act_fn/Constant_output_0 + - /language_model/layers.11/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.11/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.11/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.12/mlp/act_fn/Constant_output_0 + - /language_model/layers.12/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.12/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.12/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.13/mlp/act_fn/Constant_output_0 + - /language_model/layers.13/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.13/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.13/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.14/mlp/act_fn/Constant_output_0 + - /language_model/layers.14/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.14/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.14/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.15/mlp/act_fn/Constant_output_0 + - /language_model/layers.15/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.15/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.15/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.16/mlp/act_fn/Constant_output_0 + - /language_model/layers.16/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.16/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.16/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.17/mlp/act_fn/Constant_output_0 + - /language_model/layers.17/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.17/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.17/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.18/mlp/act_fn/Constant_output_0 + - /language_model/layers.18/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.18/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.18/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.19/mlp/act_fn/Constant_output_0 + - /language_model/layers.19/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.19/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.19/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.20/mlp/act_fn/Constant_output_0 + - /language_model/layers.20/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.20/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.20/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.21/mlp/act_fn/Constant_output_0 + - /language_model/layers.21/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.21/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.21/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.22/mlp/act_fn/Constant_output_0 + - /language_model/layers.22/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.22/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.22/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.23/mlp/act_fn/Constant_output_0 + - /language_model/layers.23/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.23/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.23/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.24/mlp/act_fn/Constant_output_0 + - /language_model/layers.24/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.24/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.24/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.25/mlp/act_fn/Constant_output_0 + - /language_model/layers.25/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.25/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.25/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.26/mlp/act_fn/Constant_output_0 + - /language_model/layers.26/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.26/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.26/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.27/mlp/act_fn/Constant_output_0 + - /language_model/layers.27/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.27/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.27/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.28/mlp/act_fn/Constant_output_0 + - /language_model/layers.28/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.28/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.28/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.29/mlp/act_fn/Constant_output_0 + - /language_model/layers.29/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.29/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.29/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.30/mlp/act_fn/Constant_output_0 + - /language_model/layers.30/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.30/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.30/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.31/mlp/act_fn/Constant_output_0 + - /language_model/layers.31/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.31/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.31/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.32/mlp/act_fn/Constant_output_0 + - /language_model/layers.32/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.32/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.32/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.33/mlp/act_fn/Constant_output_0 + - /language_model/layers.33/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.33/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.33/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.34/mlp/act_fn/Constant_output_0 + - /language_model/layers.34/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.34/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.34/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.35/mlp/act_fn/Constant_output_0 + - /language_model/layers.35/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.35/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.35/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.36/mlp/act_fn/Constant_output_0 + - /language_model/layers.36/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.36/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.36/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.37/mlp/act_fn/Constant_output_0 + - /language_model/layers.37/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.37/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.37/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.38/mlp/act_fn/Constant_output_0 + - /language_model/layers.38/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.38/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.38/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.39/mlp/act_fn/Constant_output_0 + - /language_model/layers.39/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.39/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.39/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.40/mlp/act_fn/Constant_output_0 + - /language_model/layers.40/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.40/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.40/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.41/mlp/act_fn/Constant_output_0 + - /language_model/layers.41/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.41/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.41/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.42/mlp/act_fn/Constant_output_0 + - /language_model/layers.42/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.42/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.42/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.43/mlp/act_fn/Constant_output_0 + - /language_model/layers.43/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.43/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.43/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.44/mlp/act_fn/Constant_output_0 + - /language_model/layers.44/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.44/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.44/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.45/mlp/act_fn/Constant_output_0 + - /language_model/layers.45/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.45/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.45/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.46/mlp/act_fn/Constant_output_0 + - /language_model/layers.46/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.46/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.46/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.47/mlp/act_fn/Constant_output_0 + - /language_model/layers.47/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.47/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.47/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.48/mlp/act_fn/Constant_output_0 + - /language_model/layers.48/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.48/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.48/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.49/mlp/act_fn/Constant_output_0 + - /language_model/layers.49/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.49/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.49/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.50/mlp/act_fn/Constant_output_0 + - /language_model/layers.50/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.50/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.50/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.51/mlp/act_fn/Constant_output_0 + - /language_model/layers.51/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.51/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.51/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.52/mlp/act_fn/Constant_output_0 + - /language_model/layers.52/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.52/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.52/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.53/mlp/act_fn/Constant_output_0 + - /language_model/layers.53/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.53/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.53/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.54/mlp/act_fn/Constant_output_0 + - /language_model/layers.54/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.54/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.54/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.55/mlp/act_fn/Constant_output_0 + - /language_model/layers.55/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.55/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.55/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.56/mlp/act_fn/Constant_output_0 + - /language_model/layers.56/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.56/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.56/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.57/mlp/act_fn/Constant_output_0 + - /language_model/layers.57/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.57/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.57/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.58/mlp/act_fn/Constant_output_0 + - /language_model/layers.58/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.58/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.58/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.59/mlp/act_fn/Constant_output_0 + - /language_model/layers.59/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.59/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.59/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.60/mlp/act_fn/Constant_output_0 + - /language_model/layers.60/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.60/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.60/mlp/act_fn/Constant_3_output_0 + - /language_model/layers.61/mlp/act_fn/Constant_output_0 + - /language_model/layers.61/mlp/act_fn/Constant_1_output_0 + - /language_model/layers.61/mlp/act_fn/Constant_2_output_0 + - /language_model/layers.61/mlp/act_fn/Constant_3_output_0 + + #Add + - /language_model/layers.0/mlp/act_fn/Add_output_0 + - /language_model/layers.0/mlp/act_fn/Add_1_output_0 + - /language_model/layers.1/mlp/act_fn/Add_output_0 + - /language_model/layers.1/mlp/act_fn/Add_1_output_0 + - /language_model/layers.2/mlp/act_fn/Add_output_0 + - /language_model/layers.2/mlp/act_fn/Add_1_output_0 + - /language_model/layers.3/mlp/act_fn/Add_output_0 + - /language_model/layers.3/mlp/act_fn/Add_1_output_0 + - /language_model/layers.4/mlp/act_fn/Add_output_0 + - /language_model/layers.4/mlp/act_fn/Add_1_output_0 + - /language_model/layers.5/mlp/act_fn/Add_output_0 + - /language_model/layers.5/mlp/act_fn/Add_1_output_0 + - /language_model/layers.6/mlp/act_fn/Add_output_0 + - /language_model/layers.6/mlp/act_fn/Add_1_output_0 + - /language_model/layers.7/mlp/act_fn/Add_output_0 + - /language_model/layers.7/mlp/act_fn/Add_1_output_0 + - /language_model/layers.8/mlp/act_fn/Add_output_0 + - /language_model/layers.8/mlp/act_fn/Add_1_output_0 + - /language_model/layers.9/mlp/act_fn/Add_output_0 + - /language_model/layers.9/mlp/act_fn/Add_1_output_0 + - /language_model/layers.10/mlp/act_fn/Add_output_0 + - /language_model/layers.10/mlp/act_fn/Add_1_output_0 + - /language_model/layers.11/mlp/act_fn/Add_output_0 + - /language_model/layers.11/mlp/act_fn/Add_1_output_0 + - /language_model/layers.12/mlp/act_fn/Add_output_0 + - /language_model/layers.12/mlp/act_fn/Add_1_output_0 + - /language_model/layers.13/mlp/act_fn/Add_output_0 + - /language_model/layers.13/mlp/act_fn/Add_1_output_0 + - /language_model/layers.14/mlp/act_fn/Add_output_0 + - /language_model/layers.14/mlp/act_fn/Add_1_output_0 + - /language_model/layers.15/mlp/act_fn/Add_output_0 + - /language_model/layers.15/mlp/act_fn/Add_1_output_0 + - /language_model/layers.16/mlp/act_fn/Add_output_0 + - /language_model/layers.16/mlp/act_fn/Add_1_output_0 + - /language_model/layers.17/mlp/act_fn/Add_output_0 + - /language_model/layers.17/mlp/act_fn/Add_1_output_0 + - /language_model/layers.18/mlp/act_fn/Add_output_0 + - /language_model/layers.18/mlp/act_fn/Add_1_output_0 + - /language_model/layers.19/mlp/act_fn/Add_output_0 + - /language_model/layers.19/mlp/act_fn/Add_1_output_0 + - /language_model/layers.20/mlp/act_fn/Add_output_0 + - /language_model/layers.20/mlp/act_fn/Add_1_output_0 + - /language_model/layers.21/mlp/act_fn/Add_output_0 + - /language_model/layers.21/mlp/act_fn/Add_1_output_0 + - /language_model/layers.22/mlp/act_fn/Add_output_0 + - /language_model/layers.22/mlp/act_fn/Add_1_output_0 + - /language_model/layers.23/mlp/act_fn/Add_output_0 + - /language_model/layers.23/mlp/act_fn/Add_1_output_0 + - /language_model/layers.24/mlp/act_fn/Add_output_0 + - /language_model/layers.24/mlp/act_fn/Add_1_output_0 + - /language_model/layers.25/mlp/act_fn/Add_output_0 + - /language_model/layers.25/mlp/act_fn/Add_1_output_0 + - /language_model/layers.26/mlp/act_fn/Add_output_0 + - /language_model/layers.26/mlp/act_fn/Add_1_output_0 + - /language_model/layers.27/mlp/act_fn/Add_output_0 + - /language_model/layers.27/mlp/act_fn/Add_1_output_0 + - /language_model/layers.28/mlp/act_fn/Add_output_0 + - /language_model/layers.28/mlp/act_fn/Add_1_output_0 + - /language_model/layers.29/mlp/act_fn/Add_output_0 + - /language_model/layers.29/mlp/act_fn/Add_1_output_0 + - /language_model/layers.30/mlp/act_fn/Add_output_0 + - /language_model/layers.30/mlp/act_fn/Add_1_output_0 + - /language_model/layers.31/mlp/act_fn/Add_output_0 + - /language_model/layers.31/mlp/act_fn/Add_1_output_0 + - /language_model/layers.32/mlp/act_fn/Add_output_0 + - /language_model/layers.32/mlp/act_fn/Add_1_output_0 + - /language_model/layers.33/mlp/act_fn/Add_output_0 + - /language_model/layers.33/mlp/act_fn/Add_1_output_0 + - /language_model/layers.34/mlp/act_fn/Add_output_0 + - /language_model/layers.34/mlp/act_fn/Add_1_output_0 + - /language_model/layers.35/mlp/act_fn/Add_output_0 + - /language_model/layers.35/mlp/act_fn/Add_1_output_0 + - /language_model/layers.36/mlp/act_fn/Add_output_0 + - /language_model/layers.36/mlp/act_fn/Add_1_output_0 + - /language_model/layers.37/mlp/act_fn/Add_output_0 + - /language_model/layers.37/mlp/act_fn/Add_1_output_0 + - /language_model/layers.38/mlp/act_fn/Add_output_0 + - /language_model/layers.38/mlp/act_fn/Add_1_output_0 + - /language_model/layers.39/mlp/act_fn/Add_output_0 + - /language_model/layers.39/mlp/act_fn/Add_1_output_0 + - /language_model/layers.40/mlp/act_fn/Add_output_0 + - /language_model/layers.40/mlp/act_fn/Add_1_output_0 + - /language_model/layers.41/mlp/act_fn/Add_output_0 + - /language_model/layers.41/mlp/act_fn/Add_1_output_0 + - /language_model/layers.42/mlp/act_fn/Add_output_0 + - /language_model/layers.42/mlp/act_fn/Add_1_output_0 + - /language_model/layers.43/mlp/act_fn/Add_output_0 + - /language_model/layers.43/mlp/act_fn/Add_1_output_0 + - /language_model/layers.44/mlp/act_fn/Add_output_0 + - /language_model/layers.44/mlp/act_fn/Add_1_output_0 + - /language_model/layers.45/mlp/act_fn/Add_output_0 + - /language_model/layers.45/mlp/act_fn/Add_1_output_0 + - /language_model/layers.46/mlp/act_fn/Add_output_0 + - /language_model/layers.46/mlp/act_fn/Add_1_output_0 + - /language_model/layers.47/mlp/act_fn/Add_output_0 + - /language_model/layers.47/mlp/act_fn/Add_1_output_0 + - /language_model/layers.48/mlp/act_fn/Add_output_0 + - /language_model/layers.48/mlp/act_fn/Add_1_output_0 + - /language_model/layers.49/mlp/act_fn/Add_output_0 + - /language_model/layers.49/mlp/act_fn/Add_1_output_0 + - /language_model/layers.50/mlp/act_fn/Add_output_0 + - /language_model/layers.50/mlp/act_fn/Add_1_output_0 + - /language_model/layers.51/mlp/act_fn/Add_output_0 + - /language_model/layers.51/mlp/act_fn/Add_1_output_0 + - /language_model/layers.52/mlp/act_fn/Add_output_0 + - /language_model/layers.52/mlp/act_fn/Add_1_output_0 + - /language_model/layers.53/mlp/act_fn/Add_output_0 + - /language_model/layers.53/mlp/act_fn/Add_1_output_0 + - /language_model/layers.54/mlp/act_fn/Add_output_0 + - /language_model/layers.54/mlp/act_fn/Add_1_output_0 + - /language_model/layers.55/mlp/act_fn/Add_output_0 + - /language_model/layers.55/mlp/act_fn/Add_1_output_0 + - /language_model/layers.56/mlp/act_fn/Add_output_0 + - /language_model/layers.56/mlp/act_fn/Add_1_output_0 + - /language_model/layers.57/mlp/act_fn/Add_output_0 + - /language_model/layers.57/mlp/act_fn/Add_1_output_0 + - /language_model/layers.58/mlp/act_fn/Add_output_0 + - /language_model/layers.58/mlp/act_fn/Add_1_output_0 + - /language_model/layers.59/mlp/act_fn/Add_output_0 + - /language_model/layers.59/mlp/act_fn/Add_1_output_0 + - /language_model/layers.60/mlp/act_fn/Add_output_0 + - /language_model/layers.60/mlp/act_fn/Add_1_output_0 + - /language_model/layers.61/mlp/act_fn/Add_output_0 + - /language_model/layers.61/mlp/act_fn/Add_1_output_0 + + #Tanh + - /language_model/layers.0/mlp/act_fn/Tanh_output_0 + - /language_model/layers.1/mlp/act_fn/Tanh_output_0 + - /language_model/layers.2/mlp/act_fn/Tanh_output_0 + - /language_model/layers.3/mlp/act_fn/Tanh_output_0 + - /language_model/layers.4/mlp/act_fn/Tanh_output_0 + - /language_model/layers.5/mlp/act_fn/Tanh_output_0 + - /language_model/layers.6/mlp/act_fn/Tanh_output_0 + - /language_model/layers.7/mlp/act_fn/Tanh_output_0 + - /language_model/layers.8/mlp/act_fn/Tanh_output_0 + - /language_model/layers.9/mlp/act_fn/Tanh_output_0 + - /language_model/layers.10/mlp/act_fn/Tanh_output_0 + - /language_model/layers.11/mlp/act_fn/Tanh_output_0 + - /language_model/layers.12/mlp/act_fn/Tanh_output_0 + - /language_model/layers.13/mlp/act_fn/Tanh_output_0 + - /language_model/layers.14/mlp/act_fn/Tanh_output_0 + - /language_model/layers.15/mlp/act_fn/Tanh_output_0 + - /language_model/layers.16/mlp/act_fn/Tanh_output_0 + - /language_model/layers.17/mlp/act_fn/Tanh_output_0 + - /language_model/layers.18/mlp/act_fn/Tanh_output_0 + - /language_model/layers.19/mlp/act_fn/Tanh_output_0 + - /language_model/layers.20/mlp/act_fn/Tanh_output_0 + - /language_model/layers.21/mlp/act_fn/Tanh_output_0 + - /language_model/layers.22/mlp/act_fn/Tanh_output_0 + - /language_model/layers.23/mlp/act_fn/Tanh_output_0 + - /language_model/layers.24/mlp/act_fn/Tanh_output_0 + - /language_model/layers.25/mlp/act_fn/Tanh_output_0 + - /language_model/layers.26/mlp/act_fn/Tanh_output_0 + - /language_model/layers.27/mlp/act_fn/Tanh_output_0 + - /language_model/layers.28/mlp/act_fn/Tanh_output_0 + - /language_model/layers.29/mlp/act_fn/Tanh_output_0 + - /language_model/layers.30/mlp/act_fn/Tanh_output_0 + - /language_model/layers.31/mlp/act_fn/Tanh_output_0 + - /language_model/layers.32/mlp/act_fn/Tanh_output_0 + - /language_model/layers.33/mlp/act_fn/Tanh_output_0 + - /language_model/layers.34/mlp/act_fn/Tanh_output_0 + - /language_model/layers.35/mlp/act_fn/Tanh_output_0 + - /language_model/layers.36/mlp/act_fn/Tanh_output_0 + - /language_model/layers.37/mlp/act_fn/Tanh_output_0 + - /language_model/layers.38/mlp/act_fn/Tanh_output_0 + - /language_model/layers.39/mlp/act_fn/Tanh_output_0 + - /language_model/layers.40/mlp/act_fn/Tanh_output_0 + - /language_model/layers.41/mlp/act_fn/Tanh_output_0 + - /language_model/layers.42/mlp/act_fn/Tanh_output_0 + - /language_model/layers.43/mlp/act_fn/Tanh_output_0 + - /language_model/layers.44/mlp/act_fn/Tanh_output_0 + - /language_model/layers.45/mlp/act_fn/Tanh_output_0 + - /language_model/layers.46/mlp/act_fn/Tanh_output_0 + - /language_model/layers.47/mlp/act_fn/Tanh_output_0 + - /language_model/layers.48/mlp/act_fn/Tanh_output_0 + - /language_model/layers.49/mlp/act_fn/Tanh_output_0 + - /language_model/layers.50/mlp/act_fn/Tanh_output_0 + - /language_model/layers.51/mlp/act_fn/Tanh_output_0 + - /language_model/layers.52/mlp/act_fn/Tanh_output_0 + - /language_model/layers.53/mlp/act_fn/Tanh_output_0 + - /language_model/layers.54/mlp/act_fn/Tanh_output_0 + - /language_model/layers.55/mlp/act_fn/Tanh_output_0 + - /language_model/layers.56/mlp/act_fn/Tanh_output_0 + - /language_model/layers.57/mlp/act_fn/Tanh_output_0 + - /language_model/layers.58/mlp/act_fn/Tanh_output_0 + - /language_model/layers.59/mlp/act_fn/Tanh_output_0 + - /language_model/layers.60/mlp/act_fn/Tanh_output_0 + - /language_model/layers.61/mlp/act_fn/Tanh_output_0 + - /language_model/layers.0/mlp/Mul_output_0 + - /language_model/layers.1/mlp/Mul_output_0 + - /language_model/layers.2/mlp/Mul_output_0 + - /language_model/layers.3/mlp/Mul_output_0 + - /language_model/layers.4/mlp/Mul_output_0 + - /language_model/layers.5/mlp/Mul_output_0 + - /language_model/layers.6/mlp/Mul_output_0 + - /language_model/layers.7/mlp/Mul_output_0 + - /language_model/layers.8/mlp/Mul_output_0 + - /language_model/layers.9/mlp/Mul_output_0 + - /language_model/layers.10/mlp/Mul_output_0 + - /language_model/layers.11/mlp/Mul_output_0 + - /language_model/layers.12/mlp/Mul_output_0 + - /language_model/layers.13/mlp/Mul_output_0 + - /language_model/layers.14/mlp/Mul_output_0 + - /language_model/layers.15/mlp/Mul_output_0 + - /language_model/layers.16/mlp/Mul_output_0 + - /language_model/layers.17/mlp/Mul_output_0 + - /language_model/layers.18/mlp/Mul_output_0 + - /language_model/layers.19/mlp/Mul_output_0 + - /language_model/layers.20/mlp/Mul_output_0 + - /language_model/layers.21/mlp/Mul_output_0 + - /language_model/layers.22/mlp/Mul_output_0 + - /language_model/layers.23/mlp/Mul_output_0 + - /language_model/layers.24/mlp/Mul_output_0 + - /language_model/layers.25/mlp/Mul_output_0 + - /language_model/layers.26/mlp/Mul_output_0 + - /language_model/layers.27/mlp/Mul_output_0 + - /language_model/layers.28/mlp/Mul_output_0 + - /language_model/layers.29/mlp/Mul_output_0 + - /language_model/layers.30/mlp/Mul_output_0 + - /language_model/layers.31/mlp/Mul_output_0 + - /language_model/layers.32/mlp/Mul_output_0 + - /language_model/layers.33/mlp/Mul_output_0 + - /language_model/layers.34/mlp/Mul_output_0 + - /language_model/layers.35/mlp/Mul_output_0 + - /language_model/layers.36/mlp/Mul_output_0 + - /language_model/layers.37/mlp/Mul_output_0 + - /language_model/layers.38/mlp/Mul_output_0 + - /language_model/layers.39/mlp/Mul_output_0 + - /language_model/layers.40/mlp/Mul_output_0 + - /language_model/layers.41/mlp/Mul_output_0 + - /language_model/layers.42/mlp/Mul_output_0 + - /language_model/layers.43/mlp/Mul_output_0 + - /language_model/layers.44/mlp/Mul_output_0 + - /language_model/layers.45/mlp/Mul_output_0 + - /language_model/layers.46/mlp/Mul_output_0 + - /language_model/layers.47/mlp/Mul_output_0 + - /language_model/layers.48/mlp/Mul_output_0 + - /language_model/layers.49/mlp/Mul_output_0 + - /language_model/layers.50/mlp/Mul_output_0 + - /language_model/layers.51/mlp/Mul_output_0 + - /language_model/layers.52/mlp/Mul_output_0 + - /language_model/layers.53/mlp/Mul_output_0 + - /language_model/layers.54/mlp/Mul_output_0 + - /language_model/layers.55/mlp/Mul_output_0 + - /language_model/layers.56/mlp/Mul_output_0 + - /language_model/layers.57/mlp/Mul_output_0 + - /language_model/layers.58/mlp/Mul_output_0 + - /language_model/layers.59/mlp/Mul_output_0 + - /language_model/layers.60/mlp/Mul_output_0 + - /language_model/layers.61/mlp/Mul_output_0 + - /language_model/layers.0/Add_1_output_0 + - /language_model/layers.0/Add_2_output_0 + - /language_model/layers.0/Add_3_output_0 + - /language_model/layers.0/Add_output_0 + - /language_model/layers.1/Add_1_output_0 + - /language_model/layers.1/Add_2_output_0 + - /language_model/layers.1/Add_3_output_0 + - /language_model/layers.1/Add_output_0 + - /language_model/layers.2/Add_1_output_0 + - /language_model/layers.2/Add_2_output_0 + - /language_model/layers.2/Add_3_output_0 + - /language_model/layers.2/Add_output_0 + - /language_model/layers.3/Add_1_output_0 + - /language_model/layers.3/Add_2_output_0 + - /language_model/layers.3/Add_3_output_0 + - /language_model/layers.3/Add_output_0 + - /language_model/layers.4/Add_1_output_0 + - /language_model/layers.4/Add_2_output_0 + - /language_model/layers.4/Add_3_output_0 + - /language_model/layers.4/Add_output_0 + - /language_model/layers.5/Add_1_output_0 + - /language_model/layers.5/Add_2_output_0 + - /language_model/layers.5/Add_3_output_0 + - /language_model/layers.5/Add_output_0 + - /language_model/layers.6/Add_1_output_0 + - /language_model/layers.6/Add_2_output_0 + - /language_model/layers.6/Add_3_output_0 + - /language_model/layers.6/Add_output_0 + - /language_model/layers.7/Add_1_output_0 + - /language_model/layers.7/Add_2_output_0 + - /language_model/layers.7/Add_3_output_0 + - /language_model/layers.7/Add_output_0 + - /language_model/layers.8/Add_1_output_0 + - /language_model/layers.8/Add_2_output_0 + - /language_model/layers.8/Add_3_output_0 + - /language_model/layers.8/Add_output_0 + - /language_model/layers.9/Add_1_output_0 + - /language_model/layers.9/Add_2_output_0 + - /language_model/layers.9/Add_3_output_0 + - /language_model/layers.9/Add_output_0 + - /language_model/layers.10/Add_1_output_0 + - /language_model/layers.10/Add_2_output_0 + - /language_model/layers.10/Add_3_output_0 + - /language_model/layers.10/Add_output_0 + - /language_model/layers.11/Add_1_output_0 + - /language_model/layers.11/Add_2_output_0 + - /language_model/layers.11/Add_3_output_0 + - /language_model/layers.11/Add_output_0 + - /language_model/layers.12/Add_1_output_0 + - /language_model/layers.12/Add_2_output_0 + - /language_model/layers.12/Add_3_output_0 + - /language_model/layers.12/Add_output_0 + - /language_model/layers.13/Add_1_output_0 + - /language_model/layers.13/Add_2_output_0 + - /language_model/layers.13/Add_3_output_0 + - /language_model/layers.13/Add_output_0 + - /language_model/layers.14/Add_1_output_0 + - /language_model/layers.14/Add_2_output_0 + - /language_model/layers.14/Add_3_output_0 + - /language_model/layers.14/Add_output_0 + - /language_model/layers.15/Add_1_output_0 + - /language_model/layers.15/Add_2_output_0 + - /language_model/layers.15/Add_3_output_0 + - /language_model/layers.15/Add_output_0 + - /language_model/layers.16/Add_1_output_0 + - /language_model/layers.16/Add_2_output_0 + - /language_model/layers.16/Add_3_output_0 + - /language_model/layers.16/Add_output_0 + - /language_model/layers.17/Add_1_output_0 + - /language_model/layers.17/Add_2_output_0 + - /language_model/layers.17/Add_3_output_0 + - /language_model/layers.17/Add_output_0 + - /language_model/layers.18/Add_1_output_0 + - /language_model/layers.18/Add_2_output_0 + - /language_model/layers.18/Add_3_output_0 + - /language_model/layers.18/Add_output_0 + - /language_model/layers.19/Add_1_output_0 + - /language_model/layers.19/Add_2_output_0 + - /language_model/layers.19/Add_3_output_0 + - /language_model/layers.19/Add_output_0 + - /language_model/layers.20/Add_1_output_0 + - /language_model/layers.20/Add_2_output_0 + - /language_model/layers.20/Add_3_output_0 + - /language_model/layers.20/Add_output_0 + - /language_model/layers.21/Add_1_output_0 + - /language_model/layers.21/Add_2_output_0 + - /language_model/layers.21/Add_3_output_0 + - /language_model/layers.21/Add_output_0 + - /language_model/layers.22/Add_1_output_0 + - /language_model/layers.22/Add_2_output_0 + - /language_model/layers.22/Add_3_output_0 + - /language_model/layers.22/Add_output_0 + - /language_model/layers.23/Add_1_output_0 + - /language_model/layers.23/Add_2_output_0 + - /language_model/layers.23/Add_output_0 + - /language_model/layers.24/Add_1_output_0 + - /language_model/layers.24/Add_2_output_0 + - /language_model/layers.24/Add_3_output_0 + - /language_model/layers.24/Add_output_0 + - /language_model/layers.25/Add_1_output_0 + - /language_model/layers.25/Add_2_output_0 + - /language_model/layers.25/Add_3_output_0 + - /language_model/layers.25/Add_output_0 + - /language_model/layers.26/Add_1_output_0 + - /language_model/layers.26/Add_2_output_0 + - /language_model/layers.26/Add_3_output_0 + - /language_model/layers.26/Add_output_0 + - /language_model/layers.27/Add_1_output_0 + - /language_model/layers.27/Add_2_output_0 + - /language_model/layers.27/Add_3_output_0 + - /language_model/layers.27/Add_output_0 + - /language_model/layers.28/Add_1_output_0 + - /language_model/layers.28/Add_2_output_0 + - /language_model/layers.28/Add_3_output_0 + - /language_model/layers.28/Add_output_0 + - /language_model/layers.29/Add_1_output_0 + - /language_model/layers.29/Add_2_output_0 + - /language_model/layers.29/Add_3_output_0 + - /language_model/layers.29/Add_output_0 + - /language_model/layers.30/Add_1_output_0 + - /language_model/layers.30/Add_2_output_0 + - /language_model/layers.30/Add_3_output_0 + - /language_model/layers.30/Add_output_0 + - /language_model/layers.31/Add_1_output_0 + - /language_model/layers.31/Add_2_output_0 + - /language_model/layers.31/Add_3_output_0 + - /language_model/layers.31/Add_output_0 + - /language_model/layers.32/Add_1_output_0 + - /language_model/layers.32/Add_2_output_0 + - /language_model/layers.32/Add_3_output_0 + - /language_model/layers.32/Add_output_0 + - /language_model/layers.33/Add_1_output_0 + - /language_model/layers.33/Add_2_output_0 + - /language_model/layers.33/Add_3_output_0 + - /language_model/layers.33/Add_output_0 + - /language_model/layers.34/Add_1_output_0 + - /language_model/layers.34/Add_2_output_0 + - /language_model/layers.34/Add_3_output_0 + - /language_model/layers.34/Add_output_0 + - /language_model/layers.35/Add_1_output_0 + - /language_model/layers.35/Add_2_output_0 + - /language_model/layers.35/Add_3_output_0 + - /language_model/layers.35/Add_output_0 + - /language_model/layers.36/Add_1_output_0 + - /language_model/layers.36/Add_2_output_0 + - /language_model/layers.36/Add_3_output_0 + - /language_model/layers.36/Add_output_0 + - /language_model/layers.37/Add_1_output_0 + - /language_model/layers.37/Add_2_output_0 + - /language_model/layers.37/Add_3_output_0 + - /language_model/layers.37/Add_output_0 + - /language_model/layers.38/Add_1_output_0 + - /language_model/layers.38/Add_2_output_0 + - /language_model/layers.38/Add_3_output_0 + - /language_model/layers.38/Add_output_0 + - /language_model/layers.39/Add_1_output_0 + - /language_model/layers.39/Add_2_output_0 + - /language_model/layers.39/Add_3_output_0 + - /language_model/layers.39/Add_output_0 + - /language_model/layers.40/Add_1_output_0 + - /language_model/layers.40/Add_2_output_0 + - /language_model/layers.40/Add_3_output_0 + - /language_model/layers.40/Add_output_0 + - /language_model/layers.41/Add_1_output_0 + - /language_model/layers.41/Add_2_output_0 + - /language_model/layers.41/Add_3_output_0 + - /language_model/layers.41/Add_output_0 + - /language_model/layers.42/Add_1_output_0 + - /language_model/layers.42/Add_2_output_0 + - /language_model/layers.42/Add_3_output_0 + - /language_model/layers.42/Add_output_0 + - /language_model/layers.43/Add_1_output_0 + - /language_model/layers.43/Add_2_output_0 + - /language_model/layers.43/Add_3_output_0 + - /language_model/layers.43/Add_output_0 + - /language_model/layers.44/Add_1_output_0 + - /language_model/layers.44/Add_2_output_0 + - /language_model/layers.44/Add_3_output_0 + - /language_model/layers.44/Add_output_0 + - /language_model/layers.45/Add_1_output_0 + - /language_model/layers.45/Add_2_output_0 + - /language_model/layers.45/Add_3_output_0 + - /language_model/layers.45/Add_output_0 + - /language_model/layers.46/Add_1_output_0 + - /language_model/layers.46/Add_2_output_0 + - /language_model/layers.46/Add_3_output_0 + - /language_model/layers.46/Add_output_0 + - /language_model/layers.47/Add_1_output_0 + - /language_model/layers.47/Add_2_output_0 + - /language_model/layers.47/Add_3_output_0 + - /language_model/layers.47/Add_output_0 + - /language_model/layers.48/Add_1_output_0 + - /language_model/layers.48/Add_2_output_0 + - /language_model/layers.48/Add_3_output_0 + - /language_model/layers.48/Add_output_0 + - /language_model/layers.49/Add_1_output_0 + - /language_model/layers.49/Add_2_output_0 + - /language_model/layers.49/Add_3_output_0 + - /language_model/layers.49/Add_output_0 + - /language_model/layers.50/Add_1_output_0 + - /language_model/layers.50/Add_2_output_0 + - /language_model/layers.50/Add_3_output_0 + - /language_model/layers.50/Add_output_0 + - /language_model/layers.51/Add_1_output_0 + - /language_model/layers.51/Add_2_output_0 + - /language_model/layers.51/Add_3_output_0 + - /language_model/layers.51/Add_output_0 + - /language_model/layers.52/Add_1_output_0 + - /language_model/layers.52/Add_2_output_0 + - /language_model/layers.52/Add_3_output_0 + - /language_model/layers.52/Add_output_0 + - /language_model/layers.53/Add_1_output_0 + - /language_model/layers.53/Add_2_output_0 + - /language_model/layers.53/Add_3_output_0 + - /language_model/layers.53/Add_output_0 + - /language_model/layers.54/Add_1_output_0 + - /language_model/layers.54/Add_2_output_0 + - /language_model/layers.54/Add_3_output_0 + - /language_model/layers.54/Add_output_0 + - /language_model/layers.55/Add_1_output_0 + - /language_model/layers.55/Add_2_output_0 + - /language_model/layers.55/Add_3_output_0 + - /language_model/layers.55/Add_output_0 + - /language_model/layers.56/Add_1_output_0 + - /language_model/layers.56/Add_2_output_0 + - /language_model/layers.56/Add_3_output_0 + - /language_model/layers.56/Add_output_0 + - /language_model/layers.57/Add_1_output_0 + - /language_model/layers.57/Add_2_output_0 + - /language_model/layers.57/Add_3_output_0 + - /language_model/layers.57/Add_output_0 + - /language_model/layers.58/Add_1_output_0 + - /language_model/layers.58/Add_2_output_0 + - /language_model/layers.58/Add_3_output_0 + - /language_model/layers.58/Add_output_0 + - /language_model/layers.59/Add_1_output_0 + - /language_model/layers.59/Add_2_output_0 + - /language_model/layers.59/Add_3_output_0 + - /language_model/layers.59/Add_output_0 + - /language_model/layers.60/Add_1_output_0 + - /language_model/layers.60/Add_2_output_0 + - /language_model/layers.60/Add_3_output_0 + - /language_model/layers.60/Add_output_0 + - /language_model/layers.61/Add_1_output_0 + - /language_model/layers.61/Add_2_output_0 + - /language_model/layers.61/Add_3_output_0 + - /language_model/layers.61/Add_output_0 + - /language_model/norm/Add_output_0 + - /language_model/layers.0/self_attn/Mul_output_0 + - /language_model/layers.2/self_attn/Mul_output_0 + - /language_model/layers.3/self_attn/Mul_output_0 + - /language_model/layers.4/self_attn/Mul_output_0 + - /language_model/layers.5/self_attn/Mul_output_0 + - /language_model/layers.6/self_attn/Mul_output_0 + - /language_model/layers.7/self_attn/Mul_output_0 + - /language_model/layers.8/self_attn/Mul_output_0 + - /language_model/layers.9/self_attn/Mul_output_0 + - /language_model/layers.10/self_attn/Mul_output_0 + - /language_model/layers.11/self_attn/Mul_output_0 + - /language_model/layers.12/self_attn/Mul_output_0 + - /language_model/layers.13/self_attn/Mul_output_0 + - /language_model/layers.14/self_attn/Mul_output_0 + - /language_model/layers.15/self_attn/Mul_output_0 + - /language_model/layers.16/self_attn/Mul_output_0 + - /language_model/layers.17/self_attn/Mul_output_0 + - /language_model/layers.18/self_attn/Mul_output_0 + - /language_model/layers.19/self_attn/Mul_output_0 + - /language_model/layers.20/self_attn/Mul_output_0 + - /language_model/layers.21/self_attn/Mul_output_0 + - /language_model/layers.22/self_attn/Mul_output_0 + - /language_model/layers.23/self_attn/Mul_output_0 + - /language_model/layers.24/self_attn/Mul_output_0 + - /language_model/layers.25/self_attn/Mul_output_0 + - /language_model/layers.26/self_attn/Mul_output_0 + - /language_model/layers.27/self_attn/Mul_output_0 + - /language_model/layers.28/self_attn/Mul_output_0 + - /language_model/layers.29/self_attn/Mul_output_0 + - /language_model/layers.30/self_attn/Mul_output_0 + - /language_model/layers.31/self_attn/Mul_output_0 + - /language_model/layers.32/self_attn/Mul_output_0 + - /language_model/layers.33/self_attn/Mul_output_0 + - /language_model/layers.34/self_attn/Mul_output_0 + - /language_model/layers.35/self_attn/Mul_output_0 + - /language_model/layers.36/self_attn/Mul_output_0 + - /language_model/layers.37/self_attn/Mul_output_0 + - /language_model/layers.38/self_attn/Mul_output_0 + - /language_model/layers.39/self_attn/Mul_output_0 + - /language_model/layers.40/self_attn/Mul_output_0 + - /language_model/layers.41/self_attn/Mul_output_0 + - /language_model/layers.42/self_attn/Mul_output_0 + - /language_model/layers.43/self_attn/Mul_output_0 + - /language_model/layers.44/self_attn/Mul_output_0 + - /language_model/layers.45/self_attn/Mul_output_0 + - /language_model/layers.46/self_attn/Mul_output_0 + - /language_model/layers.47/self_attn/Mul_output_0 + - /language_model/layers.48/self_attn/Mul_output_0 + - /language_model/layers.49/self_attn/Mul_output_0 + - /language_model/layers.50/self_attn/Mul_output_0 + - /language_model/layers.51/self_attn/Mul_output_0 + - /language_model/layers.52/self_attn/Mul_output_0 + - /language_model/layers.53/self_attn/Mul_output_0 + - /language_model/layers.54/self_attn/Mul_output_0 + - /language_model/layers.55/self_attn/Mul_output_0 + - /language_model/layers.56/self_attn/Mul_output_0 + - /language_model/layers.57/self_attn/Mul_output_0 + - /language_model/layers.58/self_attn/Mul_output_0 + - /language_model/layers.59/self_attn/Mul_output_0 + - /language_model/layers.60/self_attn/Mul_output_0 + - /language_model/layers.61/self_attn/Mul_output_0 + - /language_model/layers.0/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.0/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.1/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.2/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.3/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.4/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.5/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.6/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.7/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.8/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.9/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.10/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.11/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.12/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.13/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.14/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.15/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.16/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.17/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.18/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.19/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.20/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.21/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.22/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.22/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.23/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.24/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.25/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.26/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.27/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.28/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.29/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.30/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.31/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.32/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.33/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.34/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.34/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.34/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.35/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.35/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.35/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.36/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.36/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.36/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.37/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.37/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.37/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.38/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.38/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.38/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.39/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.39/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.39/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.40/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.40/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.40/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.41/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.41/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.41/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.42/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.42/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.42/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.43/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.43/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.43/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.44/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.44/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.44/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.45/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.45/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.45/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.46/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.46/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.46/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.47/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.47/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.47/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.48/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.48/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.48/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.49/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.49/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.49/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.50/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.50/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.50/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.51/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.51/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.51/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.52/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.52/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.52/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.53/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.53/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.53/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.54/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.54/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.54/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.55/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.55/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.55/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.56/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.56/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.56/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.57/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.57/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.57/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.58/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.58/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.58/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.59/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.59/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.59/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.60/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.60/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.60/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/layers.61/input_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/post_attention_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/post_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/pre_feedforward_layernorm/CustomRMSNorm_output_0 + - /language_model/layers.61/self_attn/k_norm/CustomRMSNorm_output_0 + - /language_model/layers.61/self_attn/q_norm/CustomRMSNorm_output_0 + - /language_model/norm/CustomRMSNorm_output_0 diff --git a/examples/performance/compute_context_length/gpt_oss_disagg_mode_with_chunking.py b/examples/performance/compute_context_length/gpt_oss_disagg_mode_with_chunking.py new file mode 100644 index 000000000..50f513670 --- /dev/null +++ b/examples/performance/compute_context_length/gpt_oss_disagg_mode_with_chunking.py @@ -0,0 +1,190 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import os +import time + +import numpy as np +import torch +from transformers import AutoConfig, AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession + +dir_path = os.path.dirname(os.path.realpath(__file__)) +# subfunc_npi_file_path = os.path.join(dir_path, "subfunction_120b_npi.yaml") +# non_subfunc_npi_file_path = os.path.join(dir_path, "non_subfunction_120b_npi.yaml") + +model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 + +prompt = """ +Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. + +As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown. + +The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. +""" +# Run prefill +config = AutoConfig.from_pretrained(model_id) +tokenizer = AutoTokenizer.from_pretrained(model_id) +PREFILL_SEQ_LEN = 128 +CTX_LEN = 4096 + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, + qaic_config={ + "ccl_enabled": True, + }, +) + +comp_ctx_lengths_decode = [1024, 2048, 4096] + +decode_qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step + retain_full_kv=True, + prefill_only=False, + comp_ctx_lengths_decode=comp_ctx_lengths_decode, + # # split_retained_state_io=True, # This should be used for disagg serving via VLLM + # node_precision_info=non_subfunc_npi_file_path, +) + + +qeff_model1 = QEFFAutoModelForCausalLM.from_pretrained(model_id) + +# Following command errors out by default, the user is supposed to run the printed command and provide the generated qpc path as prefill_qpc_path commenting out lines 55-68 +# prefill_qpc_path = "provide path here" +prefill_qpc_path = qeff_model1.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + enable_chunking=True, + use_onnx_subfunctions=True, + # # split_retained_state_io=True, # This should be used for disagg serving via VLLM + # node_precision_info=subfunc_npi_file_path, +) + + +inputs = tokenizer(prompt, return_tensors="np", padding=True) +position_ids = inputs["attention_mask"].sum(1, keepdims=True) +generation_len = 100 # CTX_LEN - position_ids.max() +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +inputs.pop("past_key_values", None) +inputs = {k: v.detach().numpy() for k, v in inputs.items()} + + +decode_session = QAICInferenceSession(decode_qpc_path) +prefill_session = QAICInferenceSession(prefill_qpc_path) + +all_outputs = [] + +for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + + ins = time.time() + qpc_out = prefill_session.run(chunk_inputs) + + print(f"time for this run={time.time() - ins}") + for i in range(config.num_hidden_layers): + inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +all_outputs.append(np.argmax(qpc_out["logits"])) + + +def initialize_ccl(decode_inputs, comp_ctx_lengths_decode): + list_of_comp_ctx_lengths_decode = [np.zeros(length, dtype=np.int8) for length in comp_ctx_lengths_decode] + max_ccl_id = len(comp_ctx_lengths_decode) - 1 + max_position_id = np.max(decode_inputs["position_ids"]) + ccl_id_initial = 0 + ccl_id = ccl_id_initial + for i in range(ccl_id_initial, len(comp_ctx_lengths_decode)): + if max_position_id < comp_ctx_lengths_decode[i]: + ccl_id = i + break + + return ccl_id, max_ccl_id, list_of_comp_ctx_lengths_decode + + +decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, +} +for i in range(config.num_hidden_layers): + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +if comp_ctx_lengths_decode is not None: + ccl_id, max_ccl_id, list_of_comp_ctx_lengths_decode = initialize_ccl(decode_inputs, comp_ctx_lengths_decode) + decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + +st = time.time() +decode_out = decode_session.run(decode_inputs) +print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") +all_outputs.append(np.argmax(decode_out["logits"])) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, +} + +for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + +if comp_ctx_lengths_decode is not None: + ccl_id, max_ccl_id, list_of_comp_ctx_lengths_decode = initialize_ccl(loop_decode_inputs, comp_ctx_lengths_decode) + loop_decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + +st = time.time() +for i in range(generation_len - 2): + if comp_ctx_lengths_decode is not None: + # Update ccl_id and comp_ctx_lengths_decode based on the maximum position id + if pos_id >= comp_ctx_lengths_decode[ccl_id] - 1: + ccl_id = min(ccl_id + 1, max_ccl_id) + loop_decode_inputs["comp_ctx_lengths"] = list_of_comp_ctx_lengths_decode[ccl_id] + + decode_out = decode_session.run(loop_decode_inputs) + all_outputs.append(np.argmax(decode_out["logits"])) + pos_id += 1 + for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + + loop_decode_inputs.update( + { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + ) +ft = time.time() + +print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") +print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}") diff --git a/examples/text_generation/README.md b/examples/text_generation/README.md index 2d8754768..5e40b79e1 100644 --- a/examples/text_generation/README.md +++ b/examples/text_generation/README.md @@ -115,6 +115,7 @@ This example: - Demonstrates MoE model inference - Uses sparse expert activation for efficiency - Works with Qwen, Mixtral, and other MoE models +- Supports explicit ONNX subfunction enablement with `--use-onnx-subfunctions` ## CLI Workflow @@ -216,6 +217,7 @@ This uses the pre-compiled QPC for fast inference. You can run this multiple tim | `--device_group` | Device IDs to use | `[0]` | `[0]` or `[0,1,2,3]` | | `--mxfp6` | Enable MXFP6 quantization | False | Add flag to enable | | `--mxint8_kv_cache` | Enable MXINT8 KV cache | False | Add flag to enable | +| `--use-onnx-subfunctions` | Enable ONNX subfunctions for export/compile | False | Add flag to enable | | `--mos` | Memory optimization strategy | 1 | `1` or `2` | | `--aic_enable_depth_first` | Enable depth-first execution | False | Add flag to enable | @@ -312,4 +314,3 @@ This script demonstrates: By default, exported models and QPC files are stored in `~/.cache/qeff_cache`. Customize this with: - `QEFF_HOME`: Primary cache directory - `XDG_CACHE_HOME`: Alternative cache location - diff --git a/pyproject.toml b/pyproject.toml index f38bcc17d..8c0036a37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,8 +26,6 @@ dependencies = [ "peft==0.17.0", "datasets==2.20.0", "fsspec==2023.6.0", - "multidict==6.0.4", - "urllib3<2", "sentencepiece==0.2.0", "onnx==1.18.0", "onnxruntime==1.22", diff --git a/scripts/debug/README.md b/scripts/debug/README.md new file mode 100644 index 000000000..8da3a4c14 --- /dev/null +++ b/scripts/debug/README.md @@ -0,0 +1,150 @@ +# Proxy Models Examples + +## Overview + +This directory contains examples demonstrating how to enable and use **proxy models** in QEfficient. Proxy models replace specific layers (embeddings and LM heads) with dummy layers, enabling efficient model export and IO file generation for downstream optimization and validation. + +## What is a Proxy Model? + +A proxy model is a modified version of a transformer model where: +- **Embedding layers** are replaced with proxy stubs that transform token IDs into embeddings +- **Language model (LM) head layers** are replaced with proxy implementations that convert hidden states to logits + +### Benefits +- **Simplified model export**: Easier to export models for compilation and deployment +- **IO file generation**: Automatically save input/output tensors for validation and debugging + + +## Enabling Proxy Mode + +To enable proxy models, use the `enable_proxy=True` parameter when loading a model: + +```python +from QEfficient import QEFFAutoModelForCausalLM + +model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + enable_proxy=True +) +``` + +### Saving Input/Output Files + +Generate IO files during inference using `write_io=True`: + +```python +model.generate( + inputs=..., + write_io=True # Saves input/output tensors to .npy files +) +``` + +## Example Files + +### 1. **text_model.py** - Text Generation (Causal Language Models) +Demonstrates proxy model usage with GPT2 for text generation. + +**Key Features:** +- Loads a causal language model with proxy enabled +- Compiles the model for inference +- Generates text with IO file output + +**Usage:** +```bash +python text_model.py +``` + +**Model:** `openai-community/gpt2` + +--- + +### 2. **embedding_model.py** - Text Embeddings +Shows how to enable proxy mode for embedding models that extract sentence/text embeddings. + +**Key Features:** +- Loads an embedding model with proxy enabled +- Supports pooling strategies (mean, CLS, etc.) +- Generates embeddings with IO file output + +**Usage:** +```bash +python embedding_model.py +``` + +**Model:** `BAAI/bge-base-en-v1.5` + +--- + +### 3. **audio_model.py** - Audio Processing +Demonstrates proxy models for two popular audio model types: + +#### a) Speech-to-Seq2Seq (Whisper) +- Transcribes audio to text using encoder-decoder architecture +- Model: `openai/whisper-tiny` + +#### b) CTC (Connectionist Temporal Classification) - Wav2Vec2 +- Direct audio-to-text transcription +- Model: `facebook/wav2vec2-base` + +**Key Features:** +- Processes audio samples with automatic feature extraction +- Supports both Seq2Seq and CTC-based models +- Generates IO files for validation + +**Usage:** +```bash +python audio_model.py +``` + +--- + +### 4. **image_model.py** - Vision-Language Models (Multimodal) +Demonstrates proxy models for advanced vision-language models with three different execution flows. + +#### Supported Model Types: + +1. **Standard VLM** (LLaVA, Gemma3, Granite Vision) + - Standard image-to-text architecture + - Model: `llava-hf/llava-1.5-7b-hf` + +2. **InternVL** + - Advanced vision-language model with custom architecture + - Model: `OpenGVLab/InternVL2_5-1B` + +3. **Molmo** + - Open-source multimodal model + - Model: `allenai/Molmo-7B-D-0924` + +**Key Features:** +- Handles image and text inputs +- Supports multiple VLM architectures with different preprocessing pipelines +- Generates captions/descriptions with IO file output +- KV cache offloading support (`kv_offload=True`) + +**Usage:** +```bash +python image_model.py +``` + +--- + +## Generated IO Files + +When `write_io=True`, the model generates files in the qeff models directory: +- `*.npy` files: NumPy arrays containing input/output tensors +- File names indicate tensor type and layer depth +- **Use case**: Validate model outputs, compare with baseline implementations, debug inference issues + + + + + + +--- + +## References + +- [QEfficient Documentation](https://quic.github.io/efficient-transformers/index.html) +- [Model Hub](https://huggingface.co/models) +- [Transformers Documentation](https://huggingface.co/docs/transformers/) + diff --git a/scripts/debug/audio_model.py b/scripts/debug/audio_model.py new file mode 100644 index 000000000..98bad0ed6 --- /dev/null +++ b/scripts/debug/audio_model.py @@ -0,0 +1,65 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +""" +Simple example: How to enable proxy models for audio processing and generate IO files. +Demonstrates two model types: Speech-to-Seq2Seq (Whisper) and CTC (Wav2Vec2). +""" + +from datasets import load_dataset +from transformers import AutoProcessor + +from QEfficient import QEFFAutoModelForCTC, QEFFAutoModelForSpeechSeq2Seq + +print("Loading audio sample...") +dataset = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") +audio_data = dataset[0]["audio"]["array"] +sample_rate = dataset[0]["audio"]["sampling_rate"] + +# =================================================================== +# ============ Model Type 1: Speech-to-Seq2Seq (Whisper) ============ +# =================================================================== + +print("\n" + "=" * 70) +print("MODEL 1: WHISPER (Speech-to-Seq2Seq)") +print("=" * 70) + +model_name_seq2seq = "openai/whisper-tiny" +processor_seq2seq = AutoProcessor.from_pretrained(model_name_seq2seq) + +# Load proxy model +model_seq2seq = QEFFAutoModelForSpeechSeq2Seq.from_pretrained(model_name_seq2seq, enable_proxy=True) +print(model_seq2seq) + +model_seq2seq.compile(num_cores=16) + +inputs = processor_seq2seq(audio_data, sampling_rate=sample_rate, return_tensors="pt") +result = model_seq2seq.generate(inputs=inputs, generation_len=25, write_io=True) +transcription = processor_seq2seq.batch_decode(result.generated_ids)[0] +print(f"Transcription: {transcription}\n") + + +# =================================================================== +# ============ Model Type 2: CTC (Wav2Vec2) ============ +# =================================================================== + +print("=" * 70) +print("MODEL 2: WAV2VEC2 (CTC)") +print("=" * 70) + +model_name_ctc = "facebook/wav2vec2-base" +processor_ctc = AutoProcessor.from_pretrained(model_name_ctc) + +# Load proxy model +model_ctc = QEFFAutoModelForCTC.from_pretrained(model_name_ctc, enable_proxy=True) +print(model_ctc) + +model_ctc.compile(num_cores=16) + +# Generate with IO files +transcription = model_ctc.generate(processor_ctc, inputs=audio_data, write_io=True) +print(f"Transcription: {transcription}\n") diff --git a/scripts/debug/embedding_model.py b/scripts/debug/embedding_model.py new file mode 100644 index 000000000..99d406e9a --- /dev/null +++ b/scripts/debug/embedding_model.py @@ -0,0 +1,29 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +""" +Simple example: How to enable proxy model for embeddings and generate IO files. +""" + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModel + +model_name = "BAAI/bge-base-en-v1.5" +test_text = "My name is John" + +# Load proxy model (enable_proxy=True replaces embeddings with proxy implementations) +model = QEFFAutoModel.from_pretrained(model_name, pooling="mean", enable_proxy=True) + +model.compile(num_cores=16) + +tokenizer = AutoTokenizer.from_pretrained(model_name) +inputs = tokenizer(test_text, return_tensors="pt") + +# Generate embeddings with IO files +output = model.generate(inputs, write_io=True) +print(output) diff --git a/scripts/debug/image_model.py b/scripts/debug/image_model.py new file mode 100644 index 000000000..6aecc0b3b --- /dev/null +++ b/scripts/debug/image_model.py @@ -0,0 +1,179 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +""" +Simple example: How to enable proxy models for three different vision-language models and generate IO files. +Demonstrates three model types with different execution flows: +1. Standard VLM (LLaVA, Gemma3, granite_vision, etc.) +2. InternVL Model +3. Molmo Model +""" + +from io import BytesIO + +import requests +import torch +from PIL import Image +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoProcessor, + AutoTokenizer, +) + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM, QEFFAutoModelForImageTextToText +from QEfficient.utils.test_utils import InternProcessor + +img_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/0052a70beed5bf71b92610a43a52df6d286cd5f3/diffusers/rabbit.jpg" +query = "Describe this image." + +print("Loading image...") +img = requests.get(img_url, stream=True) +image = Image.open(BytesIO(img.content)).convert("RGB") + +# Three models with different execution flows +models = [ + { + "name": "llava-hf/llava-1.5-7b-hf", + "type": "Standard VLM", + "is_intern": False, + "is_molmo": False, + }, + { + "name": "OpenGVLab/InternVL2_5-1B", + "type": "InternVL", + "is_intern": True, + "is_molmo": False, + }, + { + "name": "allenai/Molmo-7B-D-0924", + "type": "Molmo", + "is_intern": False, + "is_molmo": True, + }, +] + +for model_config in models: + model_name = model_config["name"] + model_type = model_config["type"] + is_intern_model = model_config["is_intern"] + is_molmo_model = model_config["is_molmo"] + + print("\n" + "=" * 70) + print(f"MODEL: {model_name}") + print(f"TYPE: {model_type}") + print("=" * 70) + + config = AutoConfig.from_pretrained(model_name, trust_remote_code=True, padding=not is_molmo_model) + config._attn_implementation = "eager" if (is_intern_model or is_molmo_model) else None + + # ============ EXECUTION FLOW 1: Standard VLM (LLaVA) ============ + compile_kwargs = {} + if not is_intern_model and not is_molmo_model: + print("Execution Flow: Standard VLM") + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True, padding=True) + + # Prepare conversation + conversation = [ + { + "role": "user", + "content": [ + {"type": "text", "text": query}, + {"type": "image"}, + ], + } + ] + + prompt = processor.apply_chat_template(conversation, add_generation_prompt=True) + inputs = processor(images=image, text=prompt, return_tensors="pt") + + if "pixel_values" in inputs: + inputs["pixel_values"] = inputs["pixel_values"].to(torch.float32) + + # Load proxy model + qeff_model = QEFFAutoModelForImageTextToText.from_pretrained(model_name, kv_offload=True, enable_proxy=True) + + # ============ EXECUTION FLOW 2: InternVL Model ============ + elif is_intern_model: + print("Execution Flow: InternVL") + + model_hf = AutoModelForCausalLM.from_pretrained( + model_name, + low_cpu_mem_usage=False, + trust_remote_code=True, + config=config, + ) + + tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, use_fast=False) + processor = InternProcessor(model_hf, tokenizer) + + # Process image + image_resized = image.resize((448, 448)) + pixel_value = processor.load_image(image_resized, max_num=12) + + # Prepare prompt + question = "\n" + query + messages = [] + roles = ("<|im_start|>user\n", "<|im_start|>assistant\n") + prompt = processor( + pixel_value.unsqueeze(0), [question], messages, roles, num_patches_list=[pixel_value.shape[0]] + ) + + inputs = tokenizer(prompt, return_tensors="pt") + inputs["pixel_values"] = pixel_value.clone() + + # Load proxy model + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + config=config, + kv_offload=True, + enable_proxy=True, + ) + + compile_kwargs["num_patches"] = 1 + + # ============ EXECUTION FLOW 3: Molmo Model ============ + else: # is_molmo_model + print("Execution Flow: Molmo") + + processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True, padding=True) + + # Resize image for Molmo + image_resized = image.resize((536, 354)) + + # Process inputs + inputs = processor.process(images=[image_resized], text=query) + inputs = {k: v.unsqueeze(0) for k, v in inputs.items()} + + # Add required fields for Molmo + inputs["attention_mask"] = torch.ones((inputs["input_ids"].shape), dtype=torch.int64) + valid = inputs["image_input_idx"] > 0 + valid = valid.reshape(1, -1) + inputs["valid_idx"] = torch.nonzero(valid)[:, 1].unsqueeze(0) + inputs["pixel_values"] = inputs.pop("images") + + # Load proxy model + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_name, + config=config, + trust_remote_code=True, + kv_offload=True, + enable_proxy=True, + ) + + print("Compiling model...") + qeff_model.compile(num_devices=1, prefill_seq_len=128, ctx_len=2048, **compile_kwargs) + + # Generate with IO files + outputs = qeff_model.generate( + inputs=inputs, + generation_len=10, + write_io=True, # Saves input/output tensors to files + ) + print(f"Output: {outputs}\n") + print(f"✓ Successfully processed: {model_name}\n") diff --git a/scripts/debug/text_model.py b/scripts/debug/text_model.py new file mode 100644 index 000000000..528180c30 --- /dev/null +++ b/scripts/debug/text_model.py @@ -0,0 +1,29 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- + +""" +Simple example: How to enable proxy model and generate IO files. +""" + +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM + +model_name = "openai-community/gpt2" + +# Load proxy model (enable_proxy=True replaces embedding and LM head with proxy implementations) +model = QEFFAutoModelForCausalLM.from_pretrained(model_name, enable_proxy=True) + +model.compile(num_cores=16) + +# Generate with IO files +tokenizer = AutoTokenizer.from_pretrained(model_name) +model.generate( + prompts=["Hi there!!"], + tokenizer=tokenizer, + write_io=True, # Saves input/output tensors to files +) diff --git a/scripts/pr_report/__init__.py b/scripts/pr_report/__init__.py new file mode 100644 index 000000000..efcc11246 --- /dev/null +++ b/scripts/pr_report/__init__.py @@ -0,0 +1,49 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +QEfficient Memory Profiling + +A production-ready memory profiling solution specifically designed for QEfficient workflows. +Provides manual operation marking, comprehensive metrics collection, and professional visualization. + +Usage Example: + +```python +from scripts.memory_profiling import QEffMemoryProfiler + +profiler = QEffMemoryProfiler(verbose=True) +profiler.start_monitoring() +# ... your QEfficient code ... +profiler.stop_monitoring() +print(profiler.get_memory_report()) +profiler.generate_memory_graph() +``` +""" + +# Core profiler components +from .profiler import ( + MetricsCollector, + ProfilerConfig, + ProfileSample, + QEffMemoryProfiler, +) + +# Visualization component (imported on-demand) +try: + from .visualizer import QEffMemoryVisualizer +except ImportError: + # Handle case where matplotlib is not available + QEffMemoryVisualizer = None + +__all__ = [ + "QEffMemoryProfiler", + "ProfilerConfig", + "ProfileSample", + "MetricsCollector", + "QEffMemoryVisualizer", +] diff --git a/scripts/pr_report/pr_dashboard.py b/scripts/pr_report/pr_dashboard.py new file mode 100644 index 000000000..93d84ee2c --- /dev/null +++ b/scripts/pr_report/pr_dashboard.py @@ -0,0 +1,498 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +""" +Daily PR report generator. + +Outputs a Markdown table to stdout and writes +scripts/git_workflow/recipients.txt with resolved email addresses. +""" + +import json +import math +import os +import sys +import time +import urllib.error +import urllib.parse +import urllib.request +from datetime import datetime, timezone + +API = "https://api.github.com" +ACCEPT = "application/vnd.github+json" + +# ── GitHub API helpers ──────────────────────────────────────────────────────── + + +def gh_request(path, token, params=None): + """ + Make a single GitHub API request with up to 3 retries on rate-limit errors. + Returns (parsed_json, headers). + """ + url = API + path + if params: + url += "?" + urllib.parse.urlencode(params) + + req = urllib.request.Request(url) + req.add_header("Accept", ACCEPT) + req.add_header("Authorization", f"Bearer {token}") + req.add_header("X-GitHub-Api-Version", "2022-11-28") + + for attempt in range(3): + try: + with urllib.request.urlopen(req) as resp: + return json.loads(resp.read().decode("utf-8")), resp.headers + except urllib.error.HTTPError as e: + body = e.read().decode("utf-8", errors="replace") + # Retry on rate-limit (403 with rate-limit body, or 429) + if e.code == 429 or (e.code == 403 and "rate limit" in body.lower()): + wait = 60 * (attempt + 1) + print( + f"Rate limited on {path} (attempt {attempt + 1}/3), waiting {wait}s …", + file=sys.stderr, + ) + time.sleep(wait) + continue + print(f"HTTP {e.code} for {path}: {body[:300]}", file=sys.stderr) + raise + except urllib.error.URLError as e: + print(f"URL error for {path}: {e.reason}", file=sys.stderr) + if attempt < 2: + time.sleep(5 * (attempt + 1)) + continue + raise + + raise RuntimeError(f"GitHub API request failed after 3 retries: {path}") + + +def paginate(path, token, params=None): + """ + Fetch all pages from a GitHub list endpoint. + Uses the Link header (rel="next") for correct pagination — avoids the + off-by-one bug of stopping when len(chunk) == 100. + """ + page = 1 + out = [] + while True: + p = dict(params or {}) + p.update({"per_page": 100, "page": page}) + chunk, headers = gh_request(path, token, p) + if not chunk: + break + out.extend(chunk) + # Stop only when GitHub says there is no next page + if 'rel="next"' not in (headers.get("Link") or ""): + break + page += 1 + return out + + +def paginate_check_runs(path, token, params=None): + """ + Paginate the check-runs endpoint, which wraps results in + {"check_runs": [...], "total_count": N} instead of a plain list. + """ + page = 1 + out = [] + while True: + p = dict(params or {}) + p.update({"per_page": 100, "page": page}) + resp, headers = gh_request(path, token, p) + chunk = resp.get("check_runs", []) + out.extend(chunk) + if 'rel="next"' not in (headers.get("Link") or ""): + break + page += 1 + return out + + +# ── Utility helpers ─────────────────────────────────────────────────────────── + + +def parse_iso(dt): + return datetime.fromisoformat(dt.replace("Z", "+00:00")) + + +def is_bot(username): + """Filter out GitHub bot accounts (e.g. github-actions[bot], dependabot[bot]).""" + return "[bot]" in username + + +def summarize_reviews(reviews): + """ + Keep the latest meaningful review state per human reviewer. + Bot accounts are excluded. + States: APPROVED, CHANGES_REQUESTED, COMMENTED, DISMISSED, PENDING + """ + latest = {} + for r in sorted(reviews, key=lambda x: x.get("submitted_at") or ""): + user = (r.get("user") or {}).get("login", "unknown") + if is_bot(user): + continue + state = r.get("state", "UNKNOWN") + latest[user] = state + + approvers = sorted([u for u, s in latest.items() if s == "APPROVED"]) + changers = sorted([u for u, s in latest.items() if s == "CHANGES_REQUESTED"]) + commenters = sorted([u for u, s in latest.items() if s == "COMMENTED"]) + dismissed = sorted([u for u, s in latest.items() if s == "DISMISSED"]) + + return { + "approvers": approvers, + "changes_requested": changers, + "commenters": commenters, + "dismissed": dismissed, + "latest_map": latest, + } + + +def determine_pending_with(pr, reviews, reviews_summary, requested_reviewers): + """ + Determine who the PR is currently pending with, based on its state. + + Rules (in priority order): + 1. Draft → author (still being worked on) + 2. No reviews yet, reviewers assigned → requested reviewers + 3. No reviews yet, no reviewers assigned → author + 4. Changes requested AND no new commits since the review (unresolved) → author + 5. Changes requested AND author pushed new commits after the review (resolved) → reviewer(s) who requested changes + 6. All approved, no outstanding change requests → author (ready to merge) + 7. Only comments → requested reviewers if any, else author + + "Resolved" is detected by comparing the PR's current head SHA against the + commit_id recorded on the last CHANGES_REQUESTED review for each reviewer. + If head_sha != that commit_id, the author has pushed new commits since the + review — meaning they have addressed the feedback. + """ + author = (pr.get("user") or {}).get("login", "unknown") + is_draft = pr.get("draft", False) + head_sha = (pr.get("head") or {}).get("sha", "") + + # 1. Draft → author + if is_draft: + return author + + changes_requesters = reviews_summary["changes_requested"] + approvers = reviews_summary["approvers"] + + # 2 & 3. No reviews yet + if not changes_requesters and not approvers and not reviews_summary["commenters"]: + if requested_reviewers: + return ", ".join(requested_reviewers) + return author + + # 4 & 5. Outstanding change requests + if changes_requesters: + # For each reviewer whose latest state is CHANGES_REQUESTED, find the + # commit_id of their most recent CHANGES_REQUESTED review. + last_cr_commit_per_reviewer = {} + for r in sorted(reviews, key=lambda x: x.get("submitted_at") or ""): + user = (r.get("user") or {}).get("login", "unknown") + if is_bot(user): + continue + if r.get("state") == "CHANGES_REQUESTED": + last_cr_commit_per_reviewer[user] = r.get("commit_id", "") + + # Split reviewers into "resolved" (new commits pushed) vs "unresolved" + resolved_reviewers = [] + unresolved_reviewers = [] + for reviewer in changes_requesters: + cr_commit = last_cr_commit_per_reviewer.get(reviewer, "") + if cr_commit and head_sha and cr_commit != head_sha: + resolved_reviewers.append(reviewer) + else: + unresolved_reviewers.append(reviewer) + + if unresolved_reviewers: + # At least one reviewer's changes haven't been addressed yet → author + return author + else: + # All change requests have new commits pushed after them → pending re-review + return ", ".join(resolved_reviewers) + + # 6. All approved, no outstanding change requests → author (ready to merge) + if approvers and not changes_requesters: + return author + + # 7. Only comments → requested reviewers if any, else author + if requested_reviewers: + return ", ".join(requested_reviewers) + return author + + +def format_check_runs(check_runs): + """ + Return each individual check run name and its status. + Format: "job-name: PASS / job-name2: FAIL / ..." + """ + if not check_runs: + return "NONE" + + results = [] + for cr in sorted(check_runs, key=lambda x: x.get("name", "")): + name = cr.get("name", "unknown") + status = cr.get("status") + conclusion = cr.get("conclusion") + + if status != "completed" or conclusion is None: + state = "PENDING" + elif conclusion in ("failure", "cancelled", "timed_out", "action_required", "stale"): + state = "FAIL" + elif conclusion in ("success", "neutral", "skipped"): + state = "PASS" + else: + state = conclusion.upper() + + results.append(f"{name}: {state}") + + return " / ".join(results) + + +# ── Pie chart helper ────────────────────────────────────────────────────────── + + +def generate_pie_chart_svg(author_counts): + """ + Generate a self-contained inline SVG pie chart showing PR distribution + by author. Returns an HTML string (a
wrapping an ) that can + be embedded directly in Markdown — the markdown library passes raw HTML + blocks through unchanged. + """ + if not author_counts: + return "" + + # Sort by count descending so the largest slice starts at the top + items = sorted(author_counts.items(), key=lambda x: -x[1]) + total = sum(v for _, v in items) + + # 15-colour palette; cycles if there are more authors + colors = [ + "#4a90d9", + "#e74c3c", + "#2ecc71", + "#f39c12", + "#9b59b6", + "#1abc9c", + "#e67e22", + "#3498db", + "#e91e63", + "#00bcd4", + "#ff5722", + "#607d8b", + "#795548", + "#9c27b0", + "#4caf50", + ] + + cx, cy, r = 190, 190, 160 # pie centre and radius + legend_x = cx * 2 + 30 # legend column starts here + row_h = 22 # legend row height + svg_w = legend_x + 260 # total SVG width + svg_h = max(cy * 2, len(items) * row_h + 50) # total SVG height + + # ── Build slice paths ──────────────────────────────────────────────────── + paths_svg = "" + legend_svg = "" + start_angle = -math.pi / 2 # begin at 12 o'clock + + for i, (author, count) in enumerate(items): + angle = 2 * math.pi * count / total + end_angle = start_angle + angle + + x1 = cx + r * math.cos(start_angle) + y1 = cy + r * math.sin(start_angle) + x2 = cx + r * math.cos(end_angle) + y2 = cy + r * math.sin(end_angle) + + large_arc = 1 if angle > math.pi else 0 + color = colors[i % len(colors)] + pct = count / total * 100 + + # SVG arc path: move to centre → line to arc start → arc → close + path = f"M {cx},{cy} L {x1:.2f},{y1:.2f} A {r},{r} 0 {large_arc},1 {x2:.2f},{y2:.2f} Z" + paths_svg += ( + f' \n' + f" {author}: {count} PR{'s' if count != 1 else ''} ({pct:.1f}%)\n" + f" \n" + ) + + # Legend row + ly = 40 + i * row_h + legend_svg += ( + f' \n' + f' ' + f"{author} {count} PR{'s' if count != 1 else ''} ({pct:.1f}%)" + f"\n" + ) + + start_angle = end_angle + + # ── Assemble SVG ───────────────────────────────────────────────────────── + svg = ( + f'
\n' + f'\n' + # Chart title + f' ' + f"PR Distribution by Author (Total: {total})\n" + # Slices + + paths_svg + # Legend header + + f' Author\n' + # Legend rows + + legend_svg + "\n
\n" + ) + return svg + + +# ── Email list helper ───────────────────────────────────────────────────────── + + +def load_email_list(path): + """ + Load email_map.json — a plain JSON array of email addresses. + Returns a list of strings. + """ + try: + with open(path) as f: + data = json.load(f) + if not isinstance(data, list): + print(f"Warning: {path} should be a JSON array of email addresses.", file=sys.stderr) + return [] + return [e for e in data if isinstance(e, str) and e.strip()] + except FileNotFoundError: + print(f"Warning: email list not found at {path}", file=sys.stderr) + return [] + + +# ── Main ────────────────────────────────────────────────────────────────────── + + +def main(): + token = os.environ.get("GITHUB_TOKEN") + if not token: + print("Missing GITHUB_TOKEN", file=sys.stderr) + sys.exit(1) + + repo_full = os.environ.get("GITHUB_REPOSITORY") # owner/repo + if not repo_full or "/" not in repo_full: + print("Missing/invalid GITHUB_REPOSITORY", file=sys.stderr) + sys.exit(1) + + owner, repo = repo_full.split("/", 1) + now = datetime.now(timezone.utc) + date_str = now.strftime("%B %d, %Y %H:%M UTC") + + # Load recipient email list (path configurable via EMAIL_MAP_FILE env var) + script_dir = os.path.dirname(os.path.abspath(__file__)) + default_map = os.path.join(script_dir, "email_map.json") + email_map_path = os.environ.get("EMAIL_MAP_FILE", default_map) + recipients = load_email_list(email_map_path) + + # 1) Fetch all open PRs (correctly paginated via Link header) + pulls = paginate(f"/repos/{owner}/{repo}/pulls", token, params={"state": "open"}) + total_open = len(pulls) + + # -- Header --------------------------------------------------------------- + print(f"# Open PR Dashboard — {owner}/{repo}") + print() + print("| | |") + print("|---|---|") + print(f"| Report Date | {date_str} |") + print(f"| Open PRs | **{total_open}** |") + print() + + # -- Pie chart (author distribution) — collected in first pass ------------ + author_counts: dict = {} + for pr in pulls: + author = (pr.get("user") or {}).get("login", "unknown") + if not is_bot(author): + author_counts[author] = author_counts.get(author, 0) + 1 + + print(generate_pie_chart_svg(author_counts)) + + # -- Table ---------------------------------------------------------------- + print( + "| PR | Author | Assignee | Age (days) | Draft | Labels | Reviewers | Pending With | Review Summary | CI Checks |" + ) + print("|---|---|---|---:|:---:|---|---|---|---|---|") + + for pr in pulls: + number = pr["number"] + title = pr.get("title", "").replace("|", "\\|") + url = pr.get("html_url", "") + author = (pr.get("user") or {}).get("login", "unknown") + draft = "Yes" if pr.get("draft") else "No" + created_at = parse_iso(pr["created_at"]) + age_days = (now - created_at).days + head_sha = (pr.get("head") or {}).get("sha") + + # Assignees (already in PR payload — no extra API call) + assignees = [u["login"] for u in pr.get("assignees") or [] if not is_bot(u["login"])] + assignee_str = ", ".join(assignees) if assignees else "—" + + # Labels (already in PR payload — no extra API call) + labels = [lbl["name"].replace("|", "\\|") for lbl in pr.get("labels") or []] + labels_str = ", ".join(labels) if labels else "—" + + # 2) Requested reviewers + rr, _ = gh_request(f"/repos/{owner}/{repo}/pulls/{number}/requested_reviewers", token) + users = [u["login"] for u in rr.get("users", []) if not is_bot(u["login"])] + teams = [t["name"] for t in rr.get("teams", [])] + requested_reviewers = users + [f"team:{t}" for t in teams] + reviewers_str = ", ".join(requested_reviewers) if requested_reviewers else "—" + + # 3) Reviews submitted (paginated, bots excluded) + reviews = paginate(f"/repos/{owner}/{repo}/pulls/{number}/reviews", token) + rs = summarize_reviews(reviews) + parts = [] + if rs["changes_requested"]: + parts.append("Changes Requested: " + ", ".join(rs["changes_requested"])) + if rs["approvers"]: + parts.append("Approved: " + ", ".join(rs["approvers"])) + if rs["commenters"]: + parts.append("Commented: " + ", ".join(rs["commenters"])) + if rs["dismissed"]: + parts.append("Dismissed: " + ", ".join(rs["dismissed"])) + if not parts: + parts.append("No reviews yet") + review_summary = " / ".join(parts) + + # Pending With — smart assignment based on PR state + pending_with_str = determine_pending_with(pr, reviews, rs, requested_reviewers) + + # 4) Individual CI check runs — fully paginated + ci_str = "UNKNOWN" + if head_sha: + check_runs = paginate_check_runs( + f"/repos/{owner}/{repo}/commits/{head_sha}/check-runs", + token, + params={"filter": "latest"}, + ) + ci_str = format_check_runs(check_runs) + + pr_label = f"[#{number}]({url}) {title}" + print( + f"| {pr_label} | {author} | {assignee_str} | {age_days} | {draft} | {labels_str} | {reviewers_str} | {pending_with_str} | {review_summary} | {ci_str} |" + ) + + # -- Write recipients.txt ------------------------------------------------- + recipients_path = os.path.join(script_dir, "recipients.txt") + with open(recipients_path, "w") as f: + f.write(", ".join(recipients)) + + print(f"recipients written to {recipients_path} ({len(recipients)} addresses)", file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/tests/cloud/test_infer.py b/tests/cloud/test_infer.py index e11f69017..ed3352903 100644 --- a/tests/cloud/test_infer.py +++ b/tests/cloud/test_infer.py @@ -5,6 +5,8 @@ # # ----------------------------------------------------------------------------- +from types import SimpleNamespace + import pytest import QEfficient @@ -12,7 +14,13 @@ def check_infer( - mocker, model_name, prompt="My name is", full_batch_size=None, enable_qnn=False, image_url=None, generation_len=20 + mocker, + model_name, + prompt="My name is", + full_batch_size=None, + enable_qnn=False, + image_url=None, + generation_len=20, ): check_and_assign_cache_dir_spy = mocker.spy(QEfficient.cloud.infer, "check_and_assign_cache_dir") qeff_model_load_spy = mocker.spy(QEfficient.cloud.infer.QEFFCommonLoader, "from_pretrained") @@ -99,3 +107,42 @@ def test_infer_vlm(mocker): prompt="Describe the image.", image_url="https://i.etsystatic.com/8155076/r/il/0825c2/1594869823/il_fullxfull.1594869823_5x0w.jpg", ) + + +class _DummyQEFFModel: + def __init__(self, architecture): + self.model = SimpleNamespace(config=SimpleNamespace(architectures=[architecture])) + self.compile_kwargs = None + + def compile(self, **kwargs): + self.compile_kwargs = kwargs + return "/tmp/qpc" + + def generate(self, *args, **kwargs): + return {} + + +def _run_infer_with_dummy_model(mocker, architecture, **infer_kwargs): + dummy_model = _DummyQEFFModel(architecture=architecture) + mocker.patch.object(QEfficient.cloud.infer, "check_and_assign_cache_dir", return_value="/tmp/cache") + mocker.patch.object(QEfficient.cloud.infer.QEFFCommonLoader, "from_pretrained", return_value=dummy_model) + mocker.patch.object(QEfficient.cloud.infer, "load_hf_tokenizer", return_value=object()) + + infer( + model_name="dummy/model", + num_cores=16, + prompt=["hello"], + generation_len=1, + **infer_kwargs, + ) + return dummy_model + + +def test_infer_enables_onnx_subfunctions_when_explicitly_set(mocker): + dummy_model = _run_infer_with_dummy_model(mocker, architecture="Qwen3MoeForCausalLM", use_onnx_subfunctions=True) + assert dummy_model.compile_kwargs["use_onnx_subfunctions"] is True + + +def test_infer_keeps_onnx_subfunctions_disabled_by_default(mocker): + dummy_model = _run_infer_with_dummy_model(mocker, architecture="LlamaForCausalLM") + assert dummy_model.compile_kwargs["use_onnx_subfunctions"] is False diff --git a/tests/sample_model_tests_cpu/test_model_quickcheck.py b/tests/sample_model_tests_cpu/test_model_quickcheck.py new file mode 100644 index 000000000..3b70beeb1 --- /dev/null +++ b/tests/sample_model_tests_cpu/test_model_quickcheck.py @@ -0,0 +1,463 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Fast CPU regression coverage across the main model families supported by QEfficient. + +This file intentionally uses two coverage tiers: + +1. Runtime parity: + - Exact token or tensor parity across HF PyTorch, transformed PyTorch, and ORT + - Used where the repo already has a stable CPU verification path +2. Export smoke: + - Used for model families or architectures that are supported by export today, + but do not yet have a stable CPU runtime parity path in the consolidated test +""" + +import logging +import os +import shutil +import tempfile +from contextlib import contextmanager, redirect_stderr, redirect_stdout +from io import StringIO +from pathlib import Path +from typing import Dict + +import numpy as np +import onnx +import onnxruntime as ort +import pytest +import torch +from transformers import ( + AutoConfig, + AutoModel, + AutoModelForCausalLM, + AutoModelForCTC, + AutoModelForSequenceClassification, + AutoModelForSpeechSeq2Seq, + AutoTokenizer, + Qwen2Config, +) + +from QEfficient.transformers.models.modeling_auto import ( + QEFFAutoModel, + QEFFAutoModelForCausalLM, + QEFFAutoModelForCTC, + QEFFAutoModelForImageTextToText, + QEFFAutoModelForSequenceClassification, + QEFFAutoModelForSpeechSeq2Seq, +) +from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers +from QEfficient.utils.run_utils import ApiRunner + +ort.set_default_logger_severity(3) +logging.getLogger("QEfficient").setLevel(logging.ERROR) +logging.getLogger("QEfficient.base.modeling_qeff").setLevel(logging.ERROR) + + +CAUSAL_RUNTIME_MODEL_IDS = { + "gpt2": "hf-internal-testing/tiny-random-GPT2LMHeadModel", + "codegen": "hf-internal-testing/tiny-random-CodeGenForCausalLM", + "falcon": "hf-internal-testing/tiny-random-FalconForCausalLM", + "gptj": "hf-internal-testing/tiny-random-GPTJForCausalLM", + "llama": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "mistral": "hf-internal-testing/tiny-random-MistralForCausalLM", + "mixtral": "hf-internal-testing/tiny-random-MixtralForCausalLM", + "mpt": "hf-internal-testing/tiny-random-MptForCausalLM", + "phi": "hf-internal-testing/tiny-random-PhiForCausalLM", + "phi3": "tiny-random/phi-4", + "qwen2": "yujiepan/qwen2-tiny-random", + "starcoder2": "hf-internal-testing/tiny-random-Starcoder2ForCausalLM", + "granite": "hf-internal-testing/tiny-random-GraniteForCausalLM", + "olmo2": "hf-internal-testing/tiny-random-Olmo2ForCausalLM", + "gpt_oss": "tiny-random/gpt-oss-bf16", +} + +VLM_TEXT_RUNTIME_MODEL_ID = "tiny-random/gemma-3" +VLM_EXPORT_MODEL_IDS = { + "gemma3": "tiny-random/gemma-3", + "qwen2_5_vl": "optimum-intel-internal-testing/tiny-random-qwen2.5-vl", + "internvl2": "optimum-intel-internal-testing/tiny-random-internvl2", +} +TINY_TEXT_EMBEDDING_MODEL_ID = "hf-internal-testing/tiny-random-BertModel" +TINY_AUDIO_CTC_MODEL_ID = "hf-internal-testing/tiny-random-wav2vec2" +TINY_WHISPER_MODEL_ID = "hf-internal-testing/tiny-random-WhisperForConditionalGeneration" +TINY_SEQ_CLASSIFICATION_MODEL_ID = "ydshieh/tiny-random-BertForSequenceClassification" +TINY_AWQ_MODEL_ID = "optimum-intel-internal-testing/tiny-mixtral-AWQ-4bit" + +MODEL_KWARGS = {"attn_implementation": "eager"} +PREFIX_CACHING_MODEL_ID = "hf-internal-testing/tiny-random-GPT2LMHeadModel" + + +def _per_test_thread_budget() -> int: + override = os.environ.get("QEFF_NUM_THREADS") + if override: + return max(1, int(override)) + total = os.cpu_count() or 1 + workers = max(1, int(os.environ.get("PYTEST_XDIST_WORKER_COUNT", "1"))) + return max(1, total // workers) + + +def _configure_torch_threads() -> None: + threads = _per_test_thread_budget() + os.environ.setdefault("OMP_NUM_THREADS", str(threads)) + os.environ.setdefault("MKL_NUM_THREADS", str(threads)) + torch.set_num_threads(threads) + torch.set_num_interop_threads(max(1, min(4, threads))) + + +def _ort_session(onnx_path: Path) -> ort.InferenceSession: + options = ort.SessionOptions() + threads = _per_test_thread_budget() + options.intra_op_num_threads = threads + options.inter_op_num_threads = 1 + return ort.InferenceSession(str(onnx_path), sess_options=options) + + +_configure_torch_threads() + + +def _cleanup_stale_tmp_exports() -> None: + tmp_root = Path(tempfile.gettempdir()) + for pattern in ("qeff_*", "*qeff*", "*onnx*", "*qnn*"): + for path in tmp_root.glob(pattern): + try: + if path.is_dir(): + shutil.rmtree(path, ignore_errors=True) + elif path.is_file(): + path.unlink(missing_ok=True) + except OSError: + # Best-effort cleanup only. + pass + + +@pytest.fixture(scope="session", autouse=True) +def _clean_tmp_exports_before_quickcheck(): + # Avoid concurrent cleanup from all xdist workers. + worker = os.environ.get("PYTEST_XDIST_WORKER") + if worker not in (None, "gw0"): + return + _cleanup_stale_tmp_exports() + + +@contextmanager +def _suppress_native_output(): + devnull_fd = os.open(os.devnull, os.O_WRONLY) + saved_stdout_fd = os.dup(1) + saved_stderr_fd = os.dup(2) + try: + os.dup2(devnull_fd, 1) + os.dup2(devnull_fd, 2) + with redirect_stdout(StringIO()), redirect_stderr(StringIO()): + yield + finally: + os.dup2(saved_stdout_fd, 1) + os.dup2(saved_stderr_fd, 2) + os.close(saved_stdout_fd) + os.close(saved_stderr_fd) + os.close(devnull_fd) + + +def _exported_onnx_path(export_result) -> Path: + if isinstance(export_result, (list, tuple)): + export_result = export_result[-1] + onnx_path = Path(export_result) + assert onnx_path.is_file() + return onnx_path + + +def _assert_has_retained_state_outputs(onnx_path: Path) -> None: + onnx_model = onnx.load(onnx_path, load_external_data=False) + retained_outputs = [output.name for output in onnx_model.graph.output if output.name.endswith("_RetainedState")] + assert retained_outputs + + +def _run_embedding_ort(onnx_path: Path, inputs: Dict[str, torch.Tensor]) -> np.ndarray: + session = _ort_session(onnx_path) + input_names = {item.name for item in session.get_inputs()} + ort_inputs = {name: tensor.detach().numpy() for name, tensor in inputs.items() if name in input_names} + return session.run(None, ort_inputs)[0] + + +def _run_whisper_export_smoke(qeff_model: QEFFAutoModelForSpeechSeq2Seq, out_dir: Path) -> Path: + onnx_path = _exported_onnx_path(qeff_model.export(out_dir)) + _assert_has_retained_state_outputs(onnx_path) + return onnx_path + + +def _skip_on_model_fetch_error(exc: Exception, model_id: str) -> None: + pytest.skip( + f"Skipping {model_id}: model unavailable or unsupported in this environment ({type(exc).__name__}: {exc})" + ) + + +def _export_vlm_with_text_fallback(model_id: str, out_dir: Path) -> Path: + try: + config = AutoConfig.from_pretrained(model_id, trust_remote_code=True) + model_type = getattr(config, "model_type", "") + use_text_only_first = model_type in {"qwen2_5_vl", "internvl_chat"} + + if not use_text_only_first: + try: + vlm_model = QEFFAutoModelForImageTextToText.from_pretrained(model_id, trust_remote_code=True) + return _exported_onnx_path(vlm_model.export(out_dir / "full-vlm")) + except Exception: + pass + + try: + if model_type == "qwen2_5_vl" and getattr(config, "text_config", None) is not None: + qwen2_cfg_dict = config.text_config.to_dict() + qwen2_cfg_dict["model_type"] = "qwen2" + qwen2_allowed_keys = set(Qwen2Config().to_dict().keys()) + qwen2_cfg = Qwen2Config(**{k: v for k, v in qwen2_cfg_dict.items() if k in qwen2_allowed_keys}) + text_model = AutoModelForCausalLM.from_config(qwen2_cfg, trust_remote_code=True, **MODEL_KWARGS) + text_model = text_model.to(torch.float32) + text_model.eval() + qeff_text_model = QEFFAutoModelForCausalLM(text_model) + return _exported_onnx_path(qeff_text_model.export(out_dir / "text-fallback")) + + text_configs = [getattr(config, "text_config", None), getattr(config, "llm_config", None)] + for text_config in text_configs: + if text_config is None: + continue + try: + text_model = AutoModelForCausalLM.from_config( + text_config, + trust_remote_code=True, + **MODEL_KWARGS, + ) + text_model = text_model.to(torch.float32) + text_model.eval() + qeff_text_model = QEFFAutoModelForCausalLM(text_model) + return _exported_onnx_path(qeff_text_model.export(out_dir / "text-fallback")) + except Exception: + continue + raise RuntimeError(f"No text fallback config path available for {model_id}") + except Exception as text_exc: + _skip_on_model_fetch_error(text_exc, model_id) + except Exception as cfg_exc: + _skip_on_model_fetch_error(cfg_exc, model_id) + + +@pytest.mark.llm_model +@pytest.mark.parametrize( + ("model_type", "model_id"), + sorted(CAUSAL_RUNTIME_MODEL_IDS.items()), + ids=sorted(CAUSAL_RUNTIME_MODEL_IDS), +) +def test_causal_lm_cpu_runtime_parity_with_api_runner(model_type, model_id, tmp_path): + tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + if hasattr(tokenizer, "model_input_names"): + tokenizer.model_input_names = ["input_ids", "attention_mask"] + prompt = ["hello world"] + prompt_len = 8 + ctx_len = 12 + + model_hf = AutoModelForCausalLM.from_pretrained( + model_id, + **MODEL_KWARGS, + low_cpu_mem_usage=False, + trust_remote_code=True, + torch_dtype=torch.float32, + ) + model_hf.eval() + + api_runner = ApiRunner( + batch_size=1, + tokenizer=tokenizer, + config=model_hf.config, + prompt=prompt, + prompt_len=prompt_len, + ctx_len=ctx_len, + full_batch_size=None, + ) + + hf_tokens = api_runner.run_hf_model_on_pytorch(model_hf) + qeff_model = QEFFAutoModelForCausalLM(model_hf) + kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) + onnx_path = _exported_onnx_path(qeff_model.export(tmp_path)) + ort_tokens = api_runner.run_kv_model_on_ort(str(onnx_path)) + + assert np.array_equal(hf_tokens, kv_tokens.squeeze(0)) + assert np.array_equal(kv_tokens, ort_tokens) + + +@pytest.mark.llm_model +def test_vlm_text_side_runtime_parity_and_full_export(tmp_path): + tokenizer = AutoTokenizer.from_pretrained(VLM_TEXT_RUNTIME_MODEL_ID, trust_remote_code=True) + config = AutoConfig.from_pretrained(VLM_TEXT_RUNTIME_MODEL_ID, trust_remote_code=True) + text_config = config.text_config + + text_model = AutoModelForCausalLM.from_config(text_config, trust_remote_code=True, **MODEL_KWARGS) + text_model.eval() + + api_runner = ApiRunner( + batch_size=1, + tokenizer=tokenizer, + config=text_model.config, + prompt=["hello world"], + prompt_len=4, + ctx_len=8, + full_batch_size=None, + ) + + hf_tokens = api_runner.run_hf_model_on_pytorch(text_model) + qeff_text_model = QEFFAutoModelForCausalLM(text_model) + kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_text_model.model) + onnx_path = _exported_onnx_path(qeff_text_model.export(tmp_path / "vlm-text")) + ort_tokens = api_runner.run_kv_model_on_ort(str(onnx_path)) + + assert np.array_equal(hf_tokens, kv_tokens.squeeze(0)) + assert np.array_equal(kv_tokens, ort_tokens) + + vlm_model = QEFFAutoModelForImageTextToText.from_pretrained(VLM_TEXT_RUNTIME_MODEL_ID, trust_remote_code=True) + vlm_onnx_path = _exported_onnx_path(vlm_model.export(tmp_path / "vlm-full")) + assert vlm_onnx_path.name.endswith(".onnx") + + +@pytest.mark.llm_model +@pytest.mark.parametrize( + ("vlm_name", "model_id"), + sorted(VLM_EXPORT_MODEL_IDS.items()), + ids=sorted(VLM_EXPORT_MODEL_IDS), +) +def test_vlm_export_smoke_additional_models(vlm_name, model_id, tmp_path): + vlm_onnx_path = _export_vlm_with_text_fallback(model_id, tmp_path / f"vlm-{vlm_name}") + assert vlm_onnx_path.name.endswith(".onnx") + + +@pytest.mark.llm_model +def test_text_embedding_cpu_parity_and_export(tmp_path): + tokenizer = AutoTokenizer.from_pretrained(TINY_TEXT_EMBEDDING_MODEL_ID) + model_hf = AutoModel.from_pretrained(TINY_TEXT_EMBEDDING_MODEL_ID, **MODEL_KWARGS) + model_hf.eval() + + inputs = tokenizer("hello world", return_tensors="pt") + hf_outputs = model_hf(**inputs).last_hidden_state.detach().numpy() + + qeff_model = QEFFAutoModel(model_hf) + qeff_outputs = qeff_model.generate(inputs=inputs, runtime_ai100=False).last_hidden_state.detach().numpy() + onnx_path = _exported_onnx_path(qeff_model.export(tmp_path)) + ort_outputs = _run_embedding_ort(onnx_path, inputs) + + assert np.allclose(hf_outputs, qeff_outputs, atol=1e-5) + assert np.allclose(hf_outputs, ort_outputs, atol=1e-5) + + +@pytest.mark.llm_model +def test_audio_embedding_ctc_cpu_parity_and_export(tmp_path): + processor = AutoTokenizer.from_pretrained(TINY_AUDIO_CTC_MODEL_ID) + del processor + replace_transformers_quantizers() + model_hf = AutoModelForCTC.from_pretrained(TINY_AUDIO_CTC_MODEL_ID, **MODEL_KWARGS, low_cpu_mem_usage=False) + model_hf.eval() + + from transformers import AutoProcessor + + audio_processor = AutoProcessor.from_pretrained(TINY_AUDIO_CTC_MODEL_ID) + input_values = audio_processor( + np.zeros(400, dtype=np.float32), return_tensors="pt", sampling_rate=16000 + ).input_values + + hf_logits = model_hf(input_values=input_values).logits.detach().numpy() + qeff_model = QEFFAutoModelForCTC(model_hf, pretrained_model_name_or_path=TINY_AUDIO_CTC_MODEL_ID) + onnx_path = _exported_onnx_path(qeff_model.export(tmp_path)) + ort_session = _ort_session(onnx_path) + ort_logits = ort_session.run(None, {"input_values": input_values.detach().numpy()})[0] + + assert np.allclose(hf_logits, ort_logits, atol=1e-5) + + +@pytest.mark.llm_model +def test_seq_classification_cpu_parity_and_export(tmp_path): + tokenizer = AutoTokenizer.from_pretrained(TINY_SEQ_CLASSIFICATION_MODEL_ID, trust_remote_code=True) + model_hf = AutoModelForSequenceClassification.from_pretrained( + TINY_SEQ_CLASSIFICATION_MODEL_ID, + trust_remote_code=True, + ) + model_hf.eval() + + inputs = tokenizer("quick classification check", return_tensors="pt") + hf_logits = model_hf(**inputs).logits.detach().numpy() + + qeff_model = QEFFAutoModelForSequenceClassification(model_hf) + qeff_logits = qeff_model.model(**inputs).logits.detach().numpy() + onnx_path = _exported_onnx_path(qeff_model.export(tmp_path)) + ort_session = _ort_session(onnx_path) + input_names = {item.name for item in ort_session.get_inputs()} + ort_logits = ort_session.run( + None, + {name: tensor.detach().numpy() for name, tensor in inputs.items() if name in input_names}, + )[0] + + assert np.allclose(hf_logits, qeff_logits, atol=1e-5) + assert np.allclose(hf_logits, ort_logits, atol=1e-5) + + +@pytest.mark.llm_model +def test_whisper_export_smoke(tmp_path): + model_hf = AutoModelForSpeechSeq2Seq.from_pretrained( + TINY_WHISPER_MODEL_ID, + **MODEL_KWARGS, + low_cpu_mem_usage=False, + ) + model_hf.eval() + + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model_hf, pretrained_model_name_or_path=TINY_WHISPER_MODEL_ID) + onnx_path = _run_whisper_export_smoke(qeff_model, tmp_path / "whisper") + + assert onnx_path.name.endswith(".onnx") + + +@pytest.mark.llm_model +def test_causal_subfunction_export_smoke(tmp_path): + model_id = CAUSAL_RUNTIME_MODEL_IDS["gpt2"] + model_hf = AutoModelForCausalLM.from_pretrained(model_id, **MODEL_KWARGS, low_cpu_mem_usage=False) + model_hf.eval() + qeff_model = QEFFAutoModelForCausalLM(model_hf) + + with_subfunctions_path = _exported_onnx_path( + qeff_model.export(tmp_path / "with-subfunctions", use_onnx_subfunctions=True, offload_pt_weights=False) + ) + without_subfunctions_path = _exported_onnx_path( + qeff_model.export(tmp_path / "without-subfunctions", use_onnx_subfunctions=False) + ) + + with_subfunctions_model = onnx.load(with_subfunctions_path, load_external_data=False) + without_subfunctions_model = onnx.load(without_subfunctions_path, load_external_data=False) + with_names = [func.name for func in with_subfunctions_model.functions] + without_names = [func.name for func in without_subfunctions_model.functions] + assert any("QEffGPT2Block" in name for name in with_names) + assert not any("QEffGPT2Block" in name for name in without_names) + + +@pytest.mark.llm_model +def test_prefix_caching_continuous_batching_export_and_ort_smoke(tmp_path): + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(PREFIX_CACHING_MODEL_ID, continuous_batching=True) + onnx_path = _exported_onnx_path(qeff_model.export(tmp_path / "prefix-caching")) + onnx_model = onnx.load(onnx_path, load_external_data=False) + + input_names = {inp.name for inp in onnx_model.graph.input} + output_names = {out.name for out in onnx_model.graph.output} + op_types = {node.op_type for node in onnx_model.graph.node} + assert "batch_index" in input_names + assert "CtxScatterCB" in op_types + assert "CtxGatherCB" in op_types + assert any(name.endswith("_RetainedState") for name in output_names) + + +@pytest.mark.llm_model +def test_awq_export_smoke(tmp_path): + replace_transformers_quantizers() + model_hf = AutoModelForCausalLM.from_pretrained(TINY_AWQ_MODEL_ID, low_cpu_mem_usage=False) + model_hf.eval() + + qeff_model = QEFFAutoModelForCausalLM(model_hf, pretrained_model_name_or_path=TINY_AWQ_MODEL_ID) + with _suppress_native_output(): + onnx_path = _exported_onnx_path(qeff_model.export(tmp_path)) + onnx_model = onnx.load(onnx_path, load_external_data=False) + + assert any(node.op_type == "MatMulNBits" for node in onnx_model.graph.node) diff --git a/tests/transformers/models/test_disagg_mode.py b/tests/transformers/models/test_disagg_mode.py index d11c4e397..3c5361f3e 100644 --- a/tests/transformers/models/test_disagg_mode.py +++ b/tests/transformers/models/test_disagg_mode.py @@ -16,8 +16,13 @@ from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.transformers.quantizers import replace_transformers_quantizers, undo_transformers_quantizers -model_id = "openai/gpt-oss-120b" # weights are not required to convert to fp32 - +# model id based on blocking support and chunking +model_id_blocking = [ + "openai/gpt-oss-20b", +] +model_id_chunking = [ + "Qwen/Qwen3-30B-A3B-Instruct-2507", +] prompt2 = """ Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. @@ -32,7 +37,7 @@ @pytest.mark.on_qaic @pytest.mark.llm_model -@pytest.mark.parametrize("model_id", [model_id]) +@pytest.mark.parametrize("model_id", model_id_blocking) @pytest.mark.parametrize("prompt", prompts) def test_disagg_mode_prefill(model_id, prompt): # Run prefill @@ -93,7 +98,7 @@ def test_disagg_mode_prefill(model_id, prompt): ) prefill_session = QAICInferenceSession(prefill_qpc_path) - logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) + logits_out_placeholder = np.zeros((1, 1, config.vocab_size), dtype=np.float32) prefill_session.set_buffers({"logits": logits_out_placeholder}) inputs.pop("past_key_values") inputs = {k: v.detach().numpy() for k, v in inputs.items()} @@ -105,10 +110,9 @@ def test_disagg_mode_prefill(model_id, prompt): assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 5e-2 -@pytest.mark.skip(reason="no way of currently testing this without the assert sdk") @pytest.mark.on_qaic @pytest.mark.llm_model -@pytest.mark.parametrize("model_id", [model_id]) +@pytest.mark.parametrize("model_id", model_id_chunking) @pytest.mark.parametrize("prompt", prompts) def test_disagg_mode_prefill_chunked(model_id, prompt): # Run prefill @@ -143,7 +147,7 @@ def test_disagg_mode_prefill_chunked(model_id, prompt): past_key_values = [] for i in range(config.num_hidden_layers): cache_len = CTX_LEN - pad_shape = (1, 8, cache_len, 64) + pad_shape = (1, config.num_key_value_heads, cache_len, config.head_dim) past_key = torch.zeros((pad_shape), dtype=torch.float32) past_value = torch.zeros((pad_shape), dtype=torch.float32) pkv = (past_key, past_value) @@ -178,7 +182,7 @@ def test_disagg_mode_prefill_chunked(model_id, prompt): prefill_session.skip_buffers( [x for x in prefill_session.input_names + prefill_session.output_names if x.startswith("past_")] ) - logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) + logits_out_placeholder = np.zeros((1, 1, config.vocab_size), dtype=np.float32) prefill_session.set_buffers({"logits": logits_out_placeholder}) inputs.pop("past_key_values") inputs = {k: v.detach().numpy() for k, v in inputs.items()} @@ -192,3 +196,309 @@ def test_disagg_mode_prefill_chunked(model_id, prompt): del prefill_session # Check QAIC output isclose with QEFF pytorch output assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 8e-2 + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_id", model_id_blocking) +@pytest.mark.parametrize("prompt", [prompt1]) +def test_disagg_mode_prefill_only_and_decode_only(model_id, prompt): + # Run prefill for original pytorch model + tokenizer = AutoTokenizer.from_pretrained(model_id) + PREFILL_SEQ_LEN = 256 + CTX_LEN = 256 + inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float + padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + + replace_transformers_quantizers() + model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + config = model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()} + cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN) + ins = tokenizer(prompt, return_tensors="pt") + orig_out = model(**ins, past_key_values=cache) + + position_ids = inputs["position_ids"] + generated_ids = [] + generation_len = 10 + out = orig_out + for _ in range(1, generation_len): + next_token_id = out["logits"][:, -1, :].argmax(-1).reshape(-1, 1) + generated_ids.append(next_token_id) + position_ids = position_ids.max(1, keepdim=True).values + 1 + decode_inputs = { + "input_ids": next_token_id, + "position_ids": position_ids, + "past_key_values": out["past_key_values"], + } + out = model(**decode_inputs) + + generated_ids.append(out["logits"][:, -1, :].argmax(-1).reshape(-1, 1)) + generated_ids = np.concatenate(generated_ids, axis=1) + predicted_string = tokenizer.batch_decode(generated_ids, skip_special_tokens=True) + print("Original HF Model Outputs (Torch CPU): \n") + print("Prompt:", repr(prompt)) + print("Completion:", repr(predicted_string)) + + undo_transformers_quantizers() + + prefill_qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + prefill_qeff_model.prefill(enable=True) + config = prefill_qeff_model.model.config + past_key_values = [] + for i in range(config.num_hidden_layers): + cache_len = 128 if i % 2 == 0 else PREFILL_SEQ_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) + inputs["past_key_values"] = past_key_values + + prefill_qeff_out = prefill_qeff_model.model(**inputs) + + # Check our pytorch implementation + assert (prefill_qeff_out.logits - orig_out.logits[:, -1, :]).abs().max() < 1e-4 + + decode_qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + decode_qeff_model.prefill(enable=False) + qeff_out = prefill_qeff_out + + position_ids = inputs["position_ids"] + qeff_generated_ids = [] + for _ in range(1, generation_len): + next_token_id = qeff_out["logits"][:, -1, :].argmax(-1).reshape(-1, 1) + qeff_generated_ids.append(next_token_id) + position_ids = position_ids.max(1, keepdim=True).values + 1 + decode_inputs = { + "input_ids": next_token_id, + "position_ids": position_ids, + "past_key_values": qeff_out["past_key_values"], + } + qeff_out = decode_qeff_model.model(**decode_inputs) + + qeff_generated_ids.append(out["logits"][:, -1, :].argmax(-1).reshape(-1, 1)) + qeff_generated_ids = np.concatenate(qeff_generated_ids, axis=1) + predicted_string = tokenizer.batch_decode(qeff_generated_ids, skip_special_tokens=True) + print("QEFF Transformed Model Outputs (Torch CPU): \n") + print("Prompt:", repr(prompt)) + print("Completion:", repr(predicted_string)) + + assert (qeff_generated_ids == generated_ids).all() + + prefill_qpc_path = prefill_qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + ) + + prefill_session = QAICInferenceSession(prefill_qpc_path) + logits_out_placeholder = np.zeros((1, 1, config.vocab_size), dtype=np.float32) + prefill_session.set_buffers({"logits": logits_out_placeholder}) + inputs.pop("past_key_values") + inputs = {k: v.detach().numpy() for k, v in inputs.items()} + qpc_out = prefill_session.run(inputs) + del prefill_session + # Check QAIC output isclose with QEFF pytorch output + assert (torch.from_numpy(qpc_out["logits"]) - prefill_qeff_out.logits).abs().max() < 5e-2 + + decode_qpc_path = decode_qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step + ) + + qpc_outputs = [] + decode_session = QAICInferenceSession(decode_qpc_path) + decode_session.set_buffers({"logits": logits_out_placeholder}) + + decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, + } + + qpc_outputs.append(decode_inputs["input_ids"][0][0]) + for i in range(config.num_hidden_layers): + if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window: + k = qpc_out[f"past_key.{i}_RetainedState"] + v = qpc_out[f"past_value.{i}_RetainedState"] + mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window + decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2) + decode_inputs[f"past_value.{i}"] = np.concatenate( + (v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2 + ) + else: + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + + decode_out = decode_session.run(decode_inputs) + decode_session.skip_buffers( + [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")] + ) + pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 + for i in range(generation_len - 1): + loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + qpc_outputs.append(loop_decode_inputs["input_ids"][0][0]) + decode_out = decode_session.run(loop_decode_inputs) + pos_id += 1 + + print("QPC Outputs (AIC): \n") + print("Prompt:", repr(prompt)) + print("Completion:", repr(tokenizer.decode(qpc_outputs))) + assert (qeff_generated_ids == qpc_outputs).all() + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_id", model_id_blocking) +@pytest.mark.parametrize("prompt", [prompt1]) +def test_disagg_mode_prefix_caching(model_id, prompt): + PREFILL_SEQ_LEN = 128 + CTX_LEN = 128 * 3 + config = AutoConfig.from_pretrained(model_id, num_hidden_layers=2) + prefill_qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, num_hidden_layers=2, continuous_batching=True + ) + prefill_qeff_model.prefill(enable=True, enable_chunking=True) + prefill_qpc_path = prefill_qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + enable_chunking=True, + full_batch_size=1, + kv_cache_batch_size=2, + ) + + decode_qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, num_hidden_layers=2, continuous_batching=True + ) + decode_qeff_model.prefill(enable=False) + decode_qpc_path = decode_qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step + full_batch_size=1, + kv_cache_batch_size=2, + retain_full_kv=True, + ) + + out1, ids1 = prefix_caching_inference(model_id, prefill_qpc_path, decode_qpc_path, prompt, decode_batch_id=0) + out2, ids2 = prefix_caching_inference(model_id, prefill_qpc_path, decode_qpc_path, prompt, decode_batch_id=1) + + for i in range(config.num_hidden_layers): + assert ( + np.abs( + out1[f"past_key.{i}_RetainedState"][0, :, :, :] - out2[f"past_key.{i}_RetainedState"][1, :, :, :] + ).max() + < 5e-2 + ) + assert ( + np.abs( + out1[f"past_value.{i}_RetainedState"][0, :, :, :] - out2[f"past_value.{i}_RetainedState"][1, :, :, :] + ).max() + < 5e-2 + ) + + +def prefix_caching_inference(model_id, prefill_qpc_path, decode_qpc_path, prompt, decode_batch_id): + PREFILL_SEQ_LEN = 128 + tokenizer = AutoTokenizer.from_pretrained(model_id) + config = AutoConfig.from_pretrained(model_id, num_hidden_layers=2) + inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float + padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs["batch_index"] = np.array([[decode_batch_id]], dtype=np.int64) + + prefill_session = QAICInferenceSession(prefill_qpc_path) + logits_out_placeholder = np.zeros((1, 1, config.vocab_size), dtype=np.float32) + prefill_session.set_buffers({"logits": logits_out_placeholder}) + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + qpc_out = prefill_session.run(chunk_inputs) + del prefill_session + + qpc_outputs = [] + decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, + "batch_index": inputs["batch_index"], + } + qpc_outputs.append(decode_inputs["input_ids"][0][0]) + + decode_session = QAICInferenceSession(decode_qpc_path) + decode_session.set_buffers({"logits": logits_out_placeholder}) + generation_len = 5 + + for i in range(config.num_hidden_layers): + if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window: + k = qpc_out[f"past_key.{i}_RetainedState"] + v = qpc_out[f"past_value.{i}_RetainedState"] + mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window + decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2) + decode_inputs[f"past_value.{i}"] = np.concatenate( + (v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2 + ) + else: + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + + decode_out = decode_session.run(decode_inputs) + pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 + for i in range(generation_len - 1): + loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + "batch_index": inputs["batch_index"], + } + for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + qpc_outputs.append(loop_decode_inputs["input_ids"][0][0]) + decode_out = decode_session.run(loop_decode_inputs) + pos_id += 1 + + print("QPC Outputs (AIC): \n") + print("Prompt:", repr(prompt)) + print("Completion:", repr(tokenizer.decode(qpc_outputs))) + return qpc_out, qpc_outputs diff --git a/tests/unit_test/__init__.py b/tests/unit_test/__init__.py new file mode 100644 index 000000000..d647b73a6 --- /dev/null +++ b/tests/unit_test/__init__.py @@ -0,0 +1,6 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- diff --git a/tests/unit_test/conftest.py b/tests/unit_test/conftest.py new file mode 100644 index 000000000..3b73aff26 --- /dev/null +++ b/tests/unit_test/conftest.py @@ -0,0 +1,62 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Shared fixtures and configuration for QEfficient unit_test tests. + +CPU-only tests that do NOT require QAIC hardware. +Run with: pytest tests/unit_test/ -n auto -v +""" + +import pytest +import torch + + +def pytest_configure(config): + """Register custom markers for unit_test tests.""" + config.addinivalue_line("markers", "cpu_only: CPU-only test (no QAIC hardware required)") + config.addinivalue_line("markers", "slow: slow test (ONNX export, model loading)") + config.addinivalue_line("markers", "accuracy: accuracy test (numerical comparison between stages)") + config.addinivalue_line("markers", "causal_lm: CausalLM model test") + config.addinivalue_line("markers", "seq_classification: SeqClassification model test") + config.addinivalue_line("markers", "embedding: Embedding model test") + config.addinivalue_line("markers", "speech: Speech Seq2Seq model test") + config.addinivalue_line("markers", "transforms: PyTorch transform test") + config.addinivalue_line("markers", "cache: Cache utility test") + config.addinivalue_line("markers", "onnx: ONNX export/ORT test") + config.addinivalue_line("markers", "input_handler: InputHandler utility test") + config.addinivalue_line("markers", "diffusers: QEfficient diffusers module test") + + +def pytest_collection_modifyitems(items): + """Auto-add cpu_only marker to all tests in this directory.""" + for item in items: + if "tests/unit_test" in str(item.fspath): + item.add_marker(pytest.mark.cpu_only) + + +@pytest.fixture(autouse=True) +def set_cpu_threads(): + """Limit CPU threads per worker to avoid contention in parallel runs.""" + original = torch.get_num_threads() + torch.set_num_threads(min(4, original)) + yield + torch.set_num_threads(original) + + +@pytest.fixture(autouse=True) +def set_deterministic_seed(): + """Set random seed for reproducibility across all tests.""" + torch.manual_seed(42) + yield + + +@pytest.fixture +def tmp_export_dir(tmp_path): + """Provide a temporary directory for ONNX exports (unique per test).""" + export_dir = tmp_path / "qeff_exports" + export_dir.mkdir(parents=True, exist_ok=True) + yield export_dir diff --git a/tests/unit_test/e2e/__init__.py b/tests/unit_test/e2e/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit_test/e2e/test_embedding_e2e.py b/tests/unit_test/e2e/test_embedding_e2e.py new file mode 100644 index 000000000..0c7558fe0 --- /dev/null +++ b/tests/unit_test/e2e/test_embedding_e2e.py @@ -0,0 +1,336 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +End-to-end accuracy tests for Embedding models: HF → QEff (PoolingTransform) → ORT. + +BERT embeddings have no Qualcomm custom ops, so the full ORT pipeline works. +Key accuracy assertions: + - HF and QEff produce numerically identical hidden states + - PooledModel (mean/cls) produces correct embedding shapes + - ORT embeddings match QEff PyTorch embeddings + +Models: BertModel (mean pooling, cls pooling) +All tests run on CPU only. +""" + +import numpy as np +import pytest +import torch +from transformers import BertConfig, BertModel + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModel +from QEfficient.transformers.models.pytorch_transforms import PoolingTransform + +SEQ_LEN = 16 +VOCAB_SIZE = 500 +HIDDEN_SIZE = 64 + + +def make_tiny_bert(): + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=HIDDEN_SIZE, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=64, + ) + return BertModel(cfg).eval(), cfg + + +def make_inputs(batch=1, seq=SEQ_LEN): + return { + "input_ids": torch.randint(0, VOCAB_SIZE, (batch, seq)), + "attention_mask": torch.ones(batch, seq, dtype=torch.long), + } + + +@pytest.mark.embedding +class TestHFEmbeddingBaseline: + """HF BERT embedding model produces correct hidden states.""" + + def test_bert_last_hidden_state_shape(self): + model, cfg = make_tiny_bert() + with torch.no_grad(): + out = model(**make_inputs()) + assert out.last_hidden_state.shape == (1, SEQ_LEN, HIDDEN_SIZE) + + def test_bert_pooler_output_shape(self): + model, cfg = make_tiny_bert() + with torch.no_grad(): + out = model(**make_inputs()) + assert out.pooler_output.shape == (1, HIDDEN_SIZE) + + def test_bert_hidden_states_are_finite(self): + model, cfg = make_tiny_bert() + with torch.no_grad(): + out = model(**make_inputs()) + assert torch.isfinite(out.last_hidden_state).all() + + def test_bert_batch_hidden_state_shape(self): + model, cfg = make_tiny_bert() + with torch.no_grad(): + out = model(**make_inputs(batch=4)) + assert out.last_hidden_state.shape == (4, SEQ_LEN, HIDDEN_SIZE) + + def test_bert_mean_pooling_shape(self): + model, cfg = make_tiny_bert() + inputs = make_inputs() + with torch.no_grad(): + out = model(**inputs) + mask = inputs["attention_mask"].unsqueeze(-1).float() + mean_emb = (out.last_hidden_state * mask).sum(1) / mask.sum(1) + assert mean_emb.shape == (1, HIDDEN_SIZE) + + +@pytest.mark.embedding +@pytest.mark.accuracy +class TestPoolingTransformAccuracy: + """PoolingTransform must produce embeddings consistent with HF hidden states.""" + + def test_mean_pooled_embedding_shape(self): + model, cfg = make_tiny_bert() + pooled, _ = PoolingTransform.apply(model, pooling="mean") + with torch.no_grad(): + emb = pooled(**make_inputs()) + assert emb.shape == (1, HIDDEN_SIZE) + + def test_cls_pooled_embedding_shape(self): + model, cfg = make_tiny_bert() + pooled, _ = PoolingTransform.apply(model, pooling="cls") + with torch.no_grad(): + emb = pooled(**make_inputs()) + assert emb.shape == (1, HIDDEN_SIZE) + + def test_mean_pooled_embedding_matches_manual_mean_pool(self): + """PooledModel mean output must match manually computed mean pooling.""" + model, cfg = make_tiny_bert() + inputs = make_inputs() + with torch.no_grad(): + hf_out = model(**inputs) + mask = inputs["attention_mask"].unsqueeze(-1).float() + manual_mean = (hf_out.last_hidden_state * mask).sum(1) / mask.sum(1) + + pooled, _ = PoolingTransform.apply(model, pooling="mean") + with torch.no_grad(): + pooled_mean = pooled(**inputs) + + max_diff = (manual_mean - pooled_mean).abs().max().item() + assert max_diff < 1e-5, f"Mean pooling mismatch: max_diff={max_diff:.2e}" + + def test_cls_pooled_embedding_matches_first_token(self): + """PooledModel CLS output must match the first token hidden state.""" + model, cfg = make_tiny_bert() + inputs = make_inputs() + with torch.no_grad(): + hf_out = model(**inputs) + cls_token = hf_out.last_hidden_state[:, 0, :] + + pooled, _ = PoolingTransform.apply(model, pooling="cls") + with torch.no_grad(): + pooled_cls = pooled(**inputs) + + max_diff = (cls_token - pooled_cls).abs().max().item() + assert max_diff < 1e-5, f"CLS pooling mismatch: max_diff={max_diff:.2e}" + + def test_mean_pooled_embeddings_are_finite(self): + model, cfg = make_tiny_bert() + pooled, _ = PoolingTransform.apply(model, pooling="mean") + with torch.no_grad(): + emb = pooled(**make_inputs()) + assert torch.isfinite(emb).all() + + def test_mean_pooled_batch_shape(self): + model, cfg = make_tiny_bert() + pooled, _ = PoolingTransform.apply(model, pooling="mean") + with torch.no_grad(): + emb = pooled(**make_inputs(batch=4)) + assert emb.shape == (4, HIDDEN_SIZE) + + def test_cosine_similarity_between_different_inputs_is_in_range(self): + model, cfg = make_tiny_bert() + pooled, _ = PoolingTransform.apply(model, pooling="mean") + with torch.no_grad(): + emb1 = pooled(**make_inputs()) + emb2 = pooled(**make_inputs()) + cos_sim = torch.nn.functional.cosine_similarity(emb1, emb2).item() + assert -1.0 <= cos_sim <= 1.0, f"Cosine similarity out of range: {cos_sim}" + + def test_same_input_produces_identical_embeddings(self): + model, cfg = make_tiny_bert() + pooled, _ = PoolingTransform.apply(model, pooling="mean") + inputs = make_inputs() + with torch.no_grad(): + emb1 = pooled(**inputs) + emb2 = pooled(**inputs) + assert torch.allclose(emb1, emb2), "Same input must produce identical embeddings" + + def test_qeff_auto_model_wraps_bert(self): + model, cfg = make_tiny_bert() + qeff_model = QEFFAutoModel(model) + assert qeff_model is not None + assert hasattr(qeff_model, "model") + + def test_qeff_auto_model_forward_returns_output(self): + model, cfg = make_tiny_bert() + qeff_model = QEFFAutoModel(model) + with torch.no_grad(): + out = qeff_model.model(**make_inputs()) + assert out is not None + + def test_mean_and_cls_embeddings_differ(self): + """Mean pooling and CLS pooling must produce different embeddings.""" + model, cfg = make_tiny_bert() + inputs = make_inputs() + + pooled_mean, _ = PoolingTransform.apply(model, pooling="mean") + with torch.no_grad(): + emb_mean = pooled_mean(**inputs) + + # Re-create model for CLS (transform is in-place) + model2, _ = make_tiny_bert() + # Copy weights + model2.load_state_dict(model.state_dict()) + pooled_cls, _ = PoolingTransform.apply(model2, pooling="cls") + with torch.no_grad(): + emb_cls = pooled_cls(**inputs) + + # They should generally differ (unless all tokens are identical) + # Just check they're both valid shapes + assert emb_mean.shape == emb_cls.shape == (1, HIDDEN_SIZE) + + +@pytest.mark.embedding +@pytest.mark.accuracy +@pytest.mark.onnx +@pytest.mark.slow +class TestEmbeddingORTAccuracy: + """Full pipeline: HF → QEff (PoolingTransform) → ORT.""" + + def test_bert_onnx_export_succeeds(self, tmp_export_dir): + import os + + model, cfg = make_tiny_bert() + qeff_model = QEFFAutoModel(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + assert onnx_path is not None + assert os.path.exists(str(onnx_path)) + + def test_bert_onnx_passes_checker(self, tmp_export_dir): + import onnx + + model, cfg = make_tiny_bert() + qeff_model = QEFFAutoModel(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + onnx.checker.check_model(onnx_model) + + def test_bert_ort_hidden_states_match_qeff(self, tmp_export_dir): + """ORT hidden states must match QEff PyTorch hidden states.""" + import onnxruntime as ort + + model, cfg = make_tiny_bert() + qeff_model = QEFFAutoModel(model) + inputs = make_inputs() + + with torch.no_grad(): + pt_out = qeff_model.model(**inputs) + pt_hidden = pt_out.last_hidden_state.numpy() if hasattr(pt_out, "last_hidden_state") else pt_out[0].numpy() + + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + ort_inputs = {k: v.numpy() for k, v in inputs.items()} + output_names = [o.name for o in session.get_outputs()] + ort_out = dict(zip(output_names, session.run(output_names, ort_inputs))) + + ort_hidden = None + for name, val in ort_out.items(): + if val.shape == pt_hidden.shape: + ort_hidden = val + break + + assert ort_hidden is not None, ( + f"No ORT output matches PT hidden state shape {pt_hidden.shape}. " + f"ORT outputs: {[(k, v.shape) for k, v in ort_out.items()]}" + ) + max_diff = np.abs(pt_hidden - ort_hidden).max() + assert max_diff < 1e-4, f"Hidden state max diff QEff vs ORT: {max_diff:.2e}. Must be < 1e-4." + + def test_bert_ort_output_shape_correct(self, tmp_export_dir): + """ORT BERT output must have correct shape.""" + import onnxruntime as ort + + model, cfg = make_tiny_bert() + qeff_model = QEFFAutoModel(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + ort_inputs = {k: v.numpy() for k, v in make_inputs().items()} + output_names = [o.name for o in session.get_outputs()] + ort_out = dict(zip(output_names, session.run(output_names, ort_inputs))) + assert any(v.shape[0] == 1 for v in ort_out.values()), ( + f"No ORT output has batch dim=1. Outputs: {[(k, v.shape) for k, v in ort_out.items()]}" + ) + + def test_bert_ort_batch_hidden_states_match_qeff(self, tmp_export_dir): + """ORT batch hidden states must match QEff PyTorch for batch_size=4.""" + import onnxruntime as ort + + batch_size = 4 + model, cfg = make_tiny_bert() + qeff_model = QEFFAutoModel(model) + inputs = make_inputs(batch=batch_size) + + with torch.no_grad(): + pt_out = qeff_model.model(**inputs) + pt_hidden = pt_out.last_hidden_state.numpy() if hasattr(pt_out, "last_hidden_state") else pt_out[0].numpy() + + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + ort_inputs = {k: v.numpy() for k, v in inputs.items()} + output_names = [o.name for o in session.get_outputs()] + ort_out = dict(zip(output_names, session.run(output_names, ort_inputs))) + + ort_hidden = None + for name, val in ort_out.items(): + if val.shape == pt_hidden.shape: + ort_hidden = val + break + + if ort_hidden is not None: + max_diff = np.abs(pt_hidden - ort_hidden).max() + assert max_diff < 1e-4, f"Batch hidden state max diff: {max_diff:.2e}. Must be < 1e-4." + + def test_bert_ort_mean_pooled_embedding_matches_qeff(self, tmp_export_dir): + """ORT mean-pooled embedding argmax must match QEff PyTorch.""" + import onnxruntime as ort + + model, cfg = make_tiny_bert() + qeff_model = QEFFAutoModel(model) + inputs = make_inputs() + + with torch.no_grad(): + pt_out = qeff_model.model(**inputs) + pt_hidden = pt_out.last_hidden_state.numpy() if hasattr(pt_out, "last_hidden_state") else pt_out[0].numpy() + pt_mean = pt_hidden.mean(axis=1) + + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + ort_inputs = {k: v.numpy() for k, v in inputs.items()} + output_names = [o.name for o in session.get_outputs()] + ort_out = dict(zip(output_names, session.run(output_names, ort_inputs))) + + ort_hidden = None + for name, val in ort_out.items(): + if val.shape == pt_hidden.shape: + ort_hidden = val + break + + if ort_hidden is not None: + ort_mean = ort_hidden.mean(axis=1) + pt_top = int(pt_mean.argmax(-1)) + ort_top = int(ort_mean.argmax(-1)) + assert pt_top == ort_top, f"Mean-pooled embedding argmax mismatch: QEff={pt_top}, ORT={ort_top}" diff --git a/tests/unit_test/e2e/test_seq_classification_e2e.py b/tests/unit_test/e2e/test_seq_classification_e2e.py new file mode 100644 index 000000000..867f8beca --- /dev/null +++ b/tests/unit_test/e2e/test_seq_classification_e2e.py @@ -0,0 +1,301 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +End-to-end accuracy tests for Sequence Classification: HF → QEff → ORT. + +BERT/DeBERTa have no Qualcomm custom ops, so the full pipeline works. +All three stages must predict the same class and produce numerically close logits. + +Models: BertForSequenceClassification, DebertaV2ForSequenceClassification +All tests run on CPU only. +""" + +import numpy as np +import pytest +import torch +from transformers import ( + BertConfig, + BertForSequenceClassification, + DebertaV2Config, + DebertaV2ForSequenceClassification, +) + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForSequenceClassification + +SEQ_LEN = 16 +VOCAB_SIZE = 500 +NUM_LABELS = 3 + + +def make_tiny_bert(num_labels=NUM_LABELS): + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=64, + num_labels=num_labels, + ) + return BertForSequenceClassification(cfg).eval(), cfg + + +def make_tiny_deberta(num_labels=NUM_LABELS): + cfg = DebertaV2Config( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=64, + num_labels=num_labels, + type_vocab_size=0, + pos_att_type=["p2c", "c2p"], + ) + return DebertaV2ForSequenceClassification(cfg).eval(), cfg + + +def make_inputs(batch=1, seq=SEQ_LEN): + return { + "input_ids": torch.randint(0, VOCAB_SIZE, (batch, seq)), + "attention_mask": torch.ones(batch, seq, dtype=torch.long), + } + + +@pytest.mark.seq_classification +class TestHFSeqClassBaseline: + def test_bert_logits_shape(self): + model, cfg = make_tiny_bert() + with torch.no_grad(): + out = model(**make_inputs()) + assert out.logits.shape == (1, NUM_LABELS) + + def test_bert_batch_logits_shape(self): + model, cfg = make_tiny_bert() + with torch.no_grad(): + out = model(**make_inputs(batch=4)) + assert out.logits.shape == (4, NUM_LABELS) + + def test_bert_predicted_class_is_valid(self): + model, cfg = make_tiny_bert() + with torch.no_grad(): + pred = model(**make_inputs()).logits.argmax(-1).item() + assert 0 <= pred < NUM_LABELS + + def test_bert_logits_are_finite(self): + model, cfg = make_tiny_bert() + with torch.no_grad(): + logits = model(**make_inputs()).logits + assert torch.isfinite(logits).all() + + def test_bert_prediction_is_deterministic(self): + model, cfg = make_tiny_bert() + inputs = make_inputs() + with torch.no_grad(): + p1 = model(**inputs).logits.argmax(-1).item() + p2 = model(**inputs).logits.argmax(-1).item() + assert p1 == p2 + + def test_deberta_logits_shape(self): + try: + model, cfg = make_tiny_deberta() + with torch.no_grad(): + out = model(**make_inputs()) + assert out.logits.shape == (1, NUM_LABELS) + except Exception as e: + pytest.skip(f"DeBERTa-v2 not available: {e}") + + +@pytest.mark.seq_classification +@pytest.mark.accuracy +class TestQEffSeqClassAccuracyVsHF: + """QEff model must predict the same class as HF and produce numerically close logits.""" + + def test_bert_qeff_predicts_same_class_as_hf(self): + model, cfg = make_tiny_bert() + inputs = make_inputs() + with torch.no_grad(): + hf_class = model(**inputs).logits.argmax(-1).item() + qeff_model = QEFFAutoModelForSequenceClassification(model) + with torch.no_grad(): + qeff_class = qeff_model.model(**inputs).logits.argmax(-1).item() + assert hf_class == qeff_class, f"Class mismatch: HF={hf_class}, QEff={qeff_class}" + + def test_bert_qeff_logits_numerically_identical_to_hf(self): + model, cfg = make_tiny_bert() + inputs = make_inputs() + with torch.no_grad(): + hf_logits = model(**inputs).logits + qeff_model = QEFFAutoModelForSequenceClassification(model) + with torch.no_grad(): + qeff_logits = qeff_model.model(**inputs).logits + max_diff = (hf_logits - qeff_logits).abs().max().item() + assert max_diff < 1e-5, f"Logits differ by {max_diff:.2e}. Must be < 1e-5." + + def test_bert_qeff_logits_shape_correct(self): + model, cfg = make_tiny_bert() + qeff_model = QEFFAutoModelForSequenceClassification(model) + with torch.no_grad(): + logits = qeff_model.model(**make_inputs()).logits + assert logits.shape == (1, NUM_LABELS) + + def test_bert_qeff_logits_are_finite(self): + model, cfg = make_tiny_bert() + qeff_model = QEFFAutoModelForSequenceClassification(model) + with torch.no_grad(): + logits = qeff_model.model(**make_inputs()).logits + assert torch.isfinite(logits).all() + + def test_bert_qeff_batch_prediction_matches_hf(self): + model, cfg = make_tiny_bert() + inputs = make_inputs(batch=4) + with torch.no_grad(): + hf_classes = model(**inputs).logits.argmax(-1).tolist() + qeff_model = QEFFAutoModelForSequenceClassification(model) + with torch.no_grad(): + qeff_classes = qeff_model.model(**inputs).logits.argmax(-1).tolist() + assert hf_classes == qeff_classes, f"Batch class mismatch: HF={hf_classes}, QEff={qeff_classes}" + + def test_deberta_qeff_predicts_same_class_as_hf(self): + try: + model, cfg = make_tiny_deberta() + inputs = make_inputs() + with torch.no_grad(): + hf_class = model(**inputs).logits.argmax(-1).item() + qeff_model = QEFFAutoModelForSequenceClassification(model) + with torch.no_grad(): + qeff_class = qeff_model.model(**inputs).logits.argmax(-1).item() + assert hf_class == qeff_class, f"DeBERTa class mismatch: HF={hf_class}, QEff={qeff_class}" + except Exception as e: + pytest.skip(f"DeBERTa-v2 not available: {e}") + + +@pytest.mark.seq_classification +@pytest.mark.accuracy +@pytest.mark.onnx +@pytest.mark.slow +class TestSeqClassORTAccuracy: + """Full pipeline: HF → QEff → ORT must all predict the same class.""" + + def test_bert_ort_predicts_same_class_as_qeff(self, tmp_export_dir): + import onnxruntime as ort + + model, cfg = make_tiny_bert() + inputs = make_inputs() + qeff_model = QEFFAutoModelForSequenceClassification(model) + with torch.no_grad(): + qeff_class = qeff_model.model(**inputs).logits.argmax(-1).item() + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + ort_inputs = {k: v.numpy() for k, v in inputs.items()} + output_names = [o.name for o in session.get_outputs()] + ort_out = dict(zip(output_names, session.run(output_names, ort_inputs))) + ort_class = int(ort_out["logits"].argmax(-1)) + assert qeff_class == ort_class, f"Class mismatch QEff vs ORT: QEff={qeff_class}, ORT={ort_class}" + + def test_bert_ort_predicts_same_class_as_hf(self, tmp_export_dir): + import onnxruntime as ort + + model, cfg = make_tiny_bert() + inputs = make_inputs() + with torch.no_grad(): + hf_class = model(**inputs).logits.argmax(-1).item() + qeff_model = QEFFAutoModelForSequenceClassification(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + ort_inputs = {k: v.numpy() for k, v in inputs.items()} + output_names = [o.name for o in session.get_outputs()] + ort_out = dict(zip(output_names, session.run(output_names, ort_inputs))) + ort_class = int(ort_out["logits"].argmax(-1)) + assert hf_class == ort_class, f"Full pipeline class mismatch: HF={hf_class}, ORT={ort_class}" + + def test_bert_ort_logits_numerically_close_to_qeff(self, tmp_export_dir): + import onnxruntime as ort + + model, cfg = make_tiny_bert() + inputs = make_inputs() + qeff_model = QEFFAutoModelForSequenceClassification(model) + with torch.no_grad(): + qeff_logits = qeff_model.model(**inputs).logits.numpy() + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + ort_inputs = {k: v.numpy() for k, v in inputs.items()} + output_names = [o.name for o in session.get_outputs()] + ort_out = dict(zip(output_names, session.run(output_names, ort_inputs))) + max_diff = np.abs(qeff_logits - ort_out["logits"]).max() + assert max_diff < 1e-4, f"Logit max diff QEff vs ORT: {max_diff:.2e}. Must be < 1e-4." + + def test_bert_ort_logits_shape_correct(self, tmp_export_dir): + import onnxruntime as ort + + model, cfg = make_tiny_bert() + qeff_model = QEFFAutoModelForSequenceClassification(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + ort_inputs = {k: v.numpy() for k, v in make_inputs().items()} + output_names = [o.name for o in session.get_outputs()] + ort_out = dict(zip(output_names, session.run(output_names, ort_inputs))) + assert "logits" in ort_out + assert ort_out["logits"].shape == (1, NUM_LABELS) + + def test_bert_ort_batch_predictions_match_qeff(self, tmp_export_dir): + import onnxruntime as ort + + batch_size = 4 + model, cfg = make_tiny_bert() + inputs = make_inputs(batch=batch_size) + qeff_model = QEFFAutoModelForSequenceClassification(model) + with torch.no_grad(): + qeff_classes = qeff_model.model(**inputs).logits.argmax(-1).tolist() + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + ort_inputs = {k: v.numpy() for k, v in inputs.items()} + output_names = [o.name for o in session.get_outputs()] + ort_out = dict(zip(output_names, session.run(output_names, ort_inputs))) + ort_classes = ort_out["logits"].argmax(-1).tolist() + assert qeff_classes == ort_classes, f"Batch class mismatch: QEff={qeff_classes}, ORT={ort_classes}" + + def test_bert_onnx_passes_checker(self, tmp_export_dir): + import onnx + + model, cfg = make_tiny_bert() + qeff_model = QEFFAutoModelForSequenceClassification(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + onnx.checker.check_model(onnx_model) + + def test_bert_onnx_has_input_ids_and_logits(self, tmp_export_dir): + import onnx + + model, cfg = make_tiny_bert() + qeff_model = QEFFAutoModelForSequenceClassification(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + input_names = {inp.name for inp in onnx_model.graph.input} + output_names = {out.name for out in onnx_model.graph.output} + assert "input_ids" in input_names + assert "logits" in output_names + + def test_deberta_ort_predicts_same_class_as_hf(self, tmp_export_dir): + """DeBERTa-v2 full pipeline: HF, QEff, ORT must agree on class.""" + import onnxruntime as ort + + try: + model, cfg = make_tiny_deberta() + inputs = make_inputs() + with torch.no_grad(): + hf_class = model(**inputs).logits.argmax(-1).item() + qeff_model = QEFFAutoModelForSequenceClassification(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + ort_inputs = {k: v.numpy() for k, v in inputs.items()} + output_names = [o.name for o in session.get_outputs()] + ort_out = dict(zip(output_names, session.run(output_names, ort_inputs))) + ort_class = int(ort_out["logits"].argmax(-1)) + assert hf_class == ort_class, f"DeBERTa pipeline mismatch: HF={hf_class}, ORT={ort_class}" + except Exception as e: + pytest.skip(f"DeBERTa-v2 not available or export failed: {e}") diff --git a/tests/unit_test/e2e/test_speech_e2e.py b/tests/unit_test/e2e/test_speech_e2e.py new file mode 100644 index 000000000..71f9b50c5 --- /dev/null +++ b/tests/unit_test/e2e/test_speech_e2e.py @@ -0,0 +1,277 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +End-to-end tests for Speech Seq2Seq (Whisper): HF → QEff → ONNX structure. + +Key accuracy assertions: + - HF encoder produces finite hidden states with correct shape + - QEff Whisper has correct architecture (QEffWhisperEncoder, QEffWhisperDecoder) + - QEff encoder produces same hidden states as HF encoder (max_diff < 1e-5) + - QEff Whisper has QEffWhisperAttention layers + +All tests run on CPU only. +""" + +import pytest +import torch +from transformers import WhisperConfig, WhisperForConditionalGeneration + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForSpeechSeq2Seq + +D_MODEL = 64 +NUM_MEL_BINS = 80 +VOCAB_SIZE = 100 +MAX_SOURCE_POS = 32 +MAX_TARGET_POS = 32 + + +def make_tiny_whisper(): + cfg = WhisperConfig( + vocab_size=VOCAB_SIZE, + num_mel_bins=NUM_MEL_BINS, + encoder_layers=1, + encoder_attention_heads=2, + decoder_layers=1, + decoder_attention_heads=2, + decoder_ffn_dim=D_MODEL, + encoder_ffn_dim=D_MODEL, + d_model=D_MODEL, + max_source_positions=MAX_SOURCE_POS, + max_target_positions=MAX_TARGET_POS, + decoder_start_token_id=1, + eos_token_id=2, + pad_token_id=0, + bos_token_id=1, + ) + return WhisperForConditionalGeneration(cfg).eval(), cfg + + +def make_mel_input(batch=1, seq_len=64): + return torch.randn(batch, NUM_MEL_BINS, seq_len) + + +@pytest.mark.speech +class TestHFWhisperBaseline: + """HF Whisper model runs correctly on CPU.""" + + def test_encoder_output_shape(self): + model, cfg = make_tiny_whisper() + mel = make_mel_input(seq_len=64) + with torch.no_grad(): + enc_out = model.model.encoder(mel) + assert enc_out.last_hidden_state is not None + assert enc_out.last_hidden_state.shape[-1] == D_MODEL + + def test_encoder_hidden_states_are_finite(self): + model, cfg = make_tiny_whisper() + mel = make_mel_input(seq_len=64) + with torch.no_grad(): + enc_out = model.model.encoder(mel) + assert torch.isfinite(enc_out.last_hidden_state).all() + + def test_full_forward_returns_logits(self): + model, cfg = make_tiny_whisper() + mel = make_mel_input(seq_len=64) + decoder_input_ids = torch.tensor([[cfg.decoder_start_token_id]]) + with torch.no_grad(): + out = model(input_features=mel, decoder_input_ids=decoder_input_ids) + assert hasattr(out, "logits") + assert out.logits.shape[-1] == VOCAB_SIZE + + def test_logits_are_finite(self): + model, cfg = make_tiny_whisper() + mel = make_mel_input(seq_len=64) + decoder_input_ids = torch.tensor([[cfg.decoder_start_token_id]]) + with torch.no_grad(): + out = model(input_features=mel, decoder_input_ids=decoder_input_ids) + assert torch.isfinite(out.logits).all() + + def test_generate_produces_tokens(self): + model, cfg = make_tiny_whisper() + mel = make_mel_input(seq_len=64) + with torch.no_grad(): + generated = model.generate(mel, max_new_tokens=3, do_sample=False) + assert generated is not None + assert generated.shape[0] == 1 + assert generated.shape[1] >= 1 + + def test_encoder_decoder_structure(self): + model, cfg = make_tiny_whisper() + assert hasattr(model.model, "encoder") + assert hasattr(model.model, "decoder") + + +@pytest.mark.speech +class TestQEffWhisperArchitecture: + """QEff Whisper must have correct architecture after KV transform.""" + + def test_qeff_whisper_wraps_without_error(self): + model, cfg = make_tiny_whisper() + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + assert qeff_model is not None + assert hasattr(qeff_model, "model") + + def test_qeff_whisper_is_eval_mode(self): + model, cfg = make_tiny_whisper() + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + assert not qeff_model.model.training + + def test_qeff_whisper_model_class_replaced(self): + from QEfficient.transformers.models.whisper.modeling_whisper import QEffWhisperForConditionalGeneration + + model, cfg = make_tiny_whisper() + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + assert isinstance(qeff_model.model, QEffWhisperForConditionalGeneration), ( + f"Expected QEffWhisperForConditionalGeneration, got {type(qeff_model.model)}" + ) + + def test_qeff_whisper_encoder_replaced(self): + from QEfficient.transformers.models.whisper.modeling_whisper import QEffWhisperEncoder + + model, cfg = make_tiny_whisper() + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + assert isinstance(qeff_model.model.model.encoder, QEffWhisperEncoder), ( + f"Expected QEffWhisperEncoder, got {type(qeff_model.model.model.encoder)}" + ) + + def test_qeff_whisper_decoder_replaced(self): + from QEfficient.transformers.models.whisper.modeling_whisper import QEffWhisperDecoder + + model, cfg = make_tiny_whisper() + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + assert isinstance(qeff_model.model.model.decoder, QEffWhisperDecoder), ( + f"Expected QEffWhisperDecoder, got {type(qeff_model.model.model.decoder)}" + ) + + def test_qeff_whisper_has_qeff_attention_layers(self): + from QEfficient.transformers.models.whisper.modeling_whisper import QEffWhisperAttention + + model, cfg = make_tiny_whisper() + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + has_qeff_attn = any(isinstance(m, QEffWhisperAttention) for m in qeff_model.model.modules()) + assert has_qeff_attn, "QEff Whisper must have QEffWhisperAttention layers" + + def test_qeff_whisper_has_positional_embedding_replaced(self): + from QEfficient.transformers.models.whisper.modeling_whisper import QEffWhisperPositionalEmbedding + + model, cfg = make_tiny_whisper() + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + has_pos_emb = any(isinstance(m, QEffWhisperPositionalEmbedding) for m in qeff_model.model.modules()) + assert has_pos_emb, "QEff Whisper must have QEffWhisperPositionalEmbedding" + + def test_qeff_whisper_model_name_property(self): + model, cfg = make_tiny_whisper() + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + assert hasattr(qeff_model, "model_name") + assert isinstance(qeff_model.model_name, str) + assert len(qeff_model.model_name) > 0 + + +@pytest.mark.speech +@pytest.mark.accuracy +class TestQEffWhisperEncoderAccuracy: + """QEff Whisper encoder must produce the same hidden states as HF encoder.""" + + def test_qeff_encoder_output_shape_matches_hf(self): + model, cfg = make_tiny_whisper() + mel = make_mel_input(seq_len=64) + with torch.no_grad(): + hf_enc = model.model.encoder(mel) + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + with torch.no_grad(): + qeff_enc = qeff_model.model.model.encoder(mel) + assert qeff_enc.last_hidden_state.shape == hf_enc.last_hidden_state.shape + + def test_qeff_encoder_hidden_states_match_hf(self): + """QEff encoder hidden states must be numerically identical to HF.""" + model, cfg = make_tiny_whisper() + mel = make_mel_input(seq_len=64) + with torch.no_grad(): + hf_hidden = model.model.encoder(mel).last_hidden_state + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + with torch.no_grad(): + qeff_hidden = qeff_model.model.model.encoder(mel).last_hidden_state + max_diff = (hf_hidden - qeff_hidden).abs().max().item() + assert max_diff < 1e-5, ( + f"Encoder hidden state mismatch: max_diff={max_diff:.2e}. " + f"QEff encoder must produce identical outputs to HF encoder." + ) + + def test_qeff_encoder_hidden_states_are_finite(self): + model, cfg = make_tiny_whisper() + mel = make_mel_input(seq_len=64) + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + with torch.no_grad(): + qeff_enc = qeff_model.model.model.encoder(mel) + assert torch.isfinite(qeff_enc.last_hidden_state).all() + + def test_qeff_encoder_deterministic(self): + model, cfg = make_tiny_whisper() + mel = make_mel_input(seq_len=64) + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + with torch.no_grad(): + h1 = qeff_model.model.model.encoder(mel).last_hidden_state + h2 = qeff_model.model.model.encoder(mel).last_hidden_state + assert torch.allclose(h1, h2), "QEff encoder must be deterministic" + + def test_qeff_encoder_batch_output_shape(self): + """QEff encoder must handle batch_size > 1.""" + model, cfg = make_tiny_whisper() + mel = make_mel_input(batch=2, seq_len=64) + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + with torch.no_grad(): + qeff_enc = qeff_model.model.model.encoder(mel) + assert qeff_enc.last_hidden_state.shape[0] == 2 + assert torch.isfinite(qeff_enc.last_hidden_state).all() + + +@pytest.mark.speech +@pytest.mark.onnx +@pytest.mark.slow +class TestWhisperONNXExport: + """Whisper ONNX export tests.""" + + def test_whisper_onnx_export_succeeds(self, tmp_export_dir): + model, cfg = make_tiny_whisper() + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + assert onnx_path is not None + + def test_whisper_onnx_files_exist(self, tmp_export_dir): + import pathlib + + model, cfg = make_tiny_whisper() + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + search_root = pathlib.Path(str(onnx_path)).parent if onnx_path else tmp_export_dir + onnx_files = list(search_root.rglob("*.onnx")) or list(tmp_export_dir.rglob("*.onnx")) + assert len(onnx_files) > 0, ( + f"No ONNX files found after Whisper export. onnx_path={onnx_path}, search_root={search_root}" + ) + + def test_whisper_onnx_encoder_passes_checker(self, tmp_export_dir): + """At least one exported Whisper ONNX file must pass onnx.checker.""" + import pathlib + + import onnx + + model, cfg = make_tiny_whisper() + qeff_model = QEFFAutoModelForSpeechSeq2Seq(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + search_root = pathlib.Path(str(onnx_path)).parent if onnx_path else tmp_export_dir + onnx_files = list(search_root.rglob("*.onnx")) or list(tmp_export_dir.rglob("*.onnx")) + assert len(onnx_files) > 0, "No ONNX files found after Whisper export" + passed = False + for f in onnx_files: + try: + m = onnx.load(str(f)) + onnx.checker.check_model(m) + passed = True + break + except Exception: + continue + assert passed, "No exported Whisper ONNX file passed onnx.checker" diff --git a/tests/unit_test/e2e/test_vlm_e2e.py b/tests/unit_test/e2e/test_vlm_e2e.py new file mode 100644 index 000000000..a4901c5ac --- /dev/null +++ b/tests/unit_test/e2e/test_vlm_e2e.py @@ -0,0 +1,413 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Tests for VLM (Vision-Language Model) pipeline in QEfficient. + +Tests verify: + - QEFFAutoModelForImageTextToText: importable, has correct class structure + - kv_offload=True routes to _QEffAutoModelForImageTextToTextDualQPC + - kv_offload=False routes to _QEFFAutoModelForImageTextToTextSingleQPC + - MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: exists and is a dict + - QEFFAutoModelForCTC: importable, has correct class structure + - VlmKVOffloadTransform / VlmNoKVOffloadTransform: importable, have module mappings + +All tests run on CPU , using tiny in-memory configs where possible. +""" + +import pytest + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForImageTextToText class structure +# --------------------------------------------------------------------------- + + +class TestQEFFAutoModelForImageTextToTextStructure: + """QEFFAutoModelForImageTextToText must have correct class-level structure.""" + + def test_importable(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForImageTextToText + + assert QEFFAutoModelForImageTextToText is not None + + def test_dual_qpc_class_importable(self): + from QEfficient.transformers.models.modeling_auto import _QEffAutoModelForImageTextToTextDualQPC + + assert _QEffAutoModelForImageTextToTextDualQPC is not None + + def test_single_qpc_class_importable(self): + from QEfficient.transformers.models.modeling_auto import _QEFFAutoModelForImageTextToTextSingleQPC + + assert _QEFFAutoModelForImageTextToTextSingleQPC is not None + + def test_dual_qpc_has_from_pretrained(self): + from QEfficient.transformers.models.modeling_auto import _QEffAutoModelForImageTextToTextDualQPC + + assert hasattr(_QEffAutoModelForImageTextToTextDualQPC, "from_pretrained") + assert callable(_QEffAutoModelForImageTextToTextDualQPC.from_pretrained) + + def test_single_qpc_has_from_pretrained(self): + from QEfficient.transformers.models.modeling_auto import _QEFFAutoModelForImageTextToTextSingleQPC + + assert hasattr(_QEFFAutoModelForImageTextToTextSingleQPC, "from_pretrained") + assert callable(_QEFFAutoModelForImageTextToTextSingleQPC.from_pretrained) + + def test_dual_qpc_has_from_pretrained_classmethod(self): + from QEfficient.transformers.models.modeling_auto import _QEffAutoModelForImageTextToTextDualQPC + + assert hasattr(_QEffAutoModelForImageTextToTextDualQPC, "from_pretrained") + assert callable(_QEffAutoModelForImageTextToTextDualQPC.from_pretrained) + + def test_single_qpc_has_pytorch_transforms(self): + from QEfficient.transformers.models.modeling_auto import _QEFFAutoModelForImageTextToTextSingleQPC + + assert hasattr(_QEFFAutoModelForImageTextToTextSingleQPC, "_pytorch_transforms") + assert isinstance(_QEFFAutoModelForImageTextToTextSingleQPC._pytorch_transforms, list) + + def test_dual_qpc_has_model_attribute_after_construction(self): + """_QEffAutoModelForImageTextToTextDualQPC instances must have a model attribute.""" + from QEfficient.transformers.models.modeling_auto import ( + QEFFAutoModelForImageTextToText, + _QEffAutoModelForImageTextToTextDualQPC, + ) + + try: + from transformers import CLIPVisionConfig, LlamaConfig, LlavaConfig, LlavaForConditionalGeneration + + vision_cfg = CLIPVisionConfig( + hidden_size=64, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=2, + image_size=32, + patch_size=16, + ) + text_cfg = LlamaConfig( + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + llava_cfg = LlavaConfig( + vision_config=vision_cfg, + text_config=text_cfg, + ignore_index=-100, + image_token_index=32000, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-1, + ) + model = LlavaForConditionalGeneration(llava_cfg).eval() + qeff = QEFFAutoModelForImageTextToText(model, kv_offload=True) + assert isinstance(qeff, _QEffAutoModelForImageTextToTextDualQPC) + assert hasattr(qeff, "model") + except Exception as e: + pytest.skip(f"Cannot create DualQPC instance: {e}") + + def test_single_qpc_has_onnx_transforms(self): + from QEfficient.transformers.models.modeling_auto import _QEFFAutoModelForImageTextToTextSingleQPC + + assert hasattr(_QEFFAutoModelForImageTextToTextSingleQPC, "_onnx_transforms") + assert isinstance(_QEFFAutoModelForImageTextToTextSingleQPC._onnx_transforms, list) + + def test_dual_qpc_has_hf_auto_class(self): + from QEfficient.transformers.models.modeling_auto import _QEffAutoModelForImageTextToTextDualQPC + + assert hasattr(_QEffAutoModelForImageTextToTextDualQPC, "_hf_auto_class") + + def test_single_qpc_has_hf_auto_class(self): + from QEfficient.transformers.models.modeling_auto import _QEFFAutoModelForImageTextToTextSingleQPC + + assert hasattr(_QEFFAutoModelForImageTextToTextSingleQPC, "_hf_auto_class") + + def test_importable_from_qefficient_public_api(self): + import QEfficient + + assert hasattr(QEfficient, "QEFFAutoModelForImageTextToText") + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForImageTextToText routing +# --------------------------------------------------------------------------- + + +class TestQEFFAutoModelForImageTextToTextRouting: + """QEFFAutoModelForImageTextToText must route to correct class based on kv_offload.""" + + def _make_tiny_llava(self): + """Create a tiny LLaVA model for routing tests.""" + try: + from transformers import CLIPVisionConfig, LlamaConfig, LlavaConfig, LlavaForConditionalGeneration + + vision_cfg = CLIPVisionConfig( + hidden_size=64, + intermediate_size=128, + num_hidden_layers=1, + num_attention_heads=2, + image_size=32, + patch_size=16, + ) + text_cfg = LlamaConfig( + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + llava_cfg = LlavaConfig( + vision_config=vision_cfg, + text_config=text_cfg, + ignore_index=-100, + image_token_index=32000, + projector_hidden_act="gelu", + vision_feature_select_strategy="default", + vision_feature_layer=-1, + ) + return LlavaForConditionalGeneration(llava_cfg).eval() + except Exception as e: + pytest.skip(f"Cannot create tiny LLaVA model: {e}") + + def test_kv_offload_false_creates_single_qpc(self): + """kv_offload=False must create _QEFFAutoModelForImageTextToTextSingleQPC.""" + from QEfficient.transformers.models.modeling_auto import ( + QEFFAutoModelForImageTextToText, + _QEFFAutoModelForImageTextToTextSingleQPC, + ) + + model = self._make_tiny_llava() + qeff = QEFFAutoModelForImageTextToText(model, kv_offload=False) + assert isinstance(qeff, _QEFFAutoModelForImageTextToTextSingleQPC), ( + f"kv_offload=False must create SingleQPC, got {type(qeff)}" + ) + + def test_kv_offload_true_creates_dual_qpc(self): + """kv_offload=True must create _QEffAutoModelForImageTextToTextDualQPC.""" + from QEfficient.transformers.models.modeling_auto import ( + QEFFAutoModelForImageTextToText, + _QEffAutoModelForImageTextToTextDualQPC, + ) + + model = self._make_tiny_llava() + qeff = QEFFAutoModelForImageTextToText(model, kv_offload=True) + assert isinstance(qeff, _QEffAutoModelForImageTextToTextDualQPC), ( + f"kv_offload=True must create DualQPC, got {type(qeff)}" + ) + + def test_default_kv_offload_creates_dual_qpc(self): + """Default kv_offload (None/True) must create _QEffAutoModelForImageTextToTextDualQPC.""" + from QEfficient.transformers.models.modeling_auto import ( + QEFFAutoModelForImageTextToText, + _QEffAutoModelForImageTextToTextDualQPC, + ) + + model = self._make_tiny_llava() + qeff = QEFFAutoModelForImageTextToText(model) + assert isinstance(qeff, _QEffAutoModelForImageTextToTextDualQPC), "Default kv_offload must create DualQPC" + + def test_single_qpc_has_model_attribute(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForImageTextToText + + model = self._make_tiny_llava() + qeff = QEFFAutoModelForImageTextToText(model, kv_offload=False) + assert hasattr(qeff, "model") + + def test_dual_qpc_has_model_attribute(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForImageTextToText + + model = self._make_tiny_llava() + qeff = QEFFAutoModelForImageTextToText(model, kv_offload=True) + assert hasattr(qeff, "model") + + def test_single_qpc_model_name_is_string(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForImageTextToText + + model = self._make_tiny_llava() + qeff = QEFFAutoModelForImageTextToText(model, kv_offload=False) + assert hasattr(qeff, "model_name") + assert isinstance(qeff.model_name, str) + assert len(qeff.model_name) > 0 + + +# --------------------------------------------------------------------------- +# Tests: MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP +# --------------------------------------------------------------------------- + + +class TestMisclassifiedCausalLMMap: + """MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP must exist and route correctly.""" + + def test_map_exists_and_is_dict(self): + from QEfficient.transformers.models.modeling_auto import ( + MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP, + ) + + assert isinstance(MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP, dict) + + def test_map_values_are_qeff_classes(self): + from QEfficient.transformers.models.modeling_auto import ( + MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP, + ) + + for key, val in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP.items(): + assert isinstance(val, type), f"Expected class for key '{key}', got {type(val)}" + + def test_map_keys_are_strings(self): + from QEfficient.transformers.models.modeling_auto import ( + MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP, + ) + + for key in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP.keys(): + assert isinstance(key, str), f"Expected string key, got {type(key)}: {key}" + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForCTC class structure +# --------------------------------------------------------------------------- + + +class TestQEFFAutoModelForCTCStructure: + """QEFFAutoModelForCTC must have correct class-level structure.""" + + def test_importable(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCTC + + assert QEFFAutoModelForCTC is not None + + def test_has_from_pretrained(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCTC + + assert hasattr(QEFFAutoModelForCTC, "from_pretrained") + assert callable(QEFFAutoModelForCTC.from_pretrained) + + def test_has_pytorch_transforms(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCTC + + assert hasattr(QEFFAutoModelForCTC, "_pytorch_transforms") + assert isinstance(QEFFAutoModelForCTC._pytorch_transforms, list) + + def test_has_onnx_transforms(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCTC + + assert hasattr(QEFFAutoModelForCTC, "_onnx_transforms") + assert isinstance(QEFFAutoModelForCTC._onnx_transforms, list) + + def test_has_hf_auto_class(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCTC + + assert hasattr(QEFFAutoModelForCTC, "_hf_auto_class") + + def test_hf_auto_class_is_auto_model_for_ctc(self): + from transformers import AutoModelForCTC + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCTC + + assert QEFFAutoModelForCTC._hf_auto_class is AutoModelForCTC + + def test_pytorch_transforms_include_custom_ops_transform(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCTC + from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform + + assert CustomOpsTransform in QEFFAutoModelForCTC._pytorch_transforms, ( + "CustomOpsTransform not in QEFFAutoModelForCTC._pytorch_transforms" + ) + + def test_onnx_transforms_include_fp16_clip(self): + from QEfficient.base.onnx_transforms import FP16ClipTransform + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCTC + + assert FP16ClipTransform in QEFFAutoModelForCTC._onnx_transforms, ( + "FP16ClipTransform not in QEFFAutoModelForCTC._onnx_transforms" + ) + + +# --------------------------------------------------------------------------- +# Tests: VLM KV Offload Transforms +# --------------------------------------------------------------------------- + + +class TestVlmKVOffloadTransforms: + """VlmKVOffloadTransform and VlmNoKVOffloadTransform must have correct structure.""" + + def test_vlm_kv_offload_transform_importable(self): + from QEfficient.transformers.models.pytorch_transforms import VlmKVOffloadTransform + + assert VlmKVOffloadTransform is not None + + def test_vlm_no_kv_offload_transform_importable(self): + from QEfficient.transformers.models.pytorch_transforms import VlmNoKVOffloadTransform + + assert VlmNoKVOffloadTransform is not None + + def test_vlm_kv_offload_has_module_mapping(self): + from QEfficient.transformers.models.pytorch_transforms import VlmKVOffloadTransform + + assert hasattr(VlmKVOffloadTransform, "_module_mapping") + assert len(VlmKVOffloadTransform._module_mapping) > 0 + + def test_vlm_no_kv_offload_has_module_mapping(self): + from QEfficient.transformers.models.pytorch_transforms import VlmNoKVOffloadTransform + + assert hasattr(VlmNoKVOffloadTransform, "_module_mapping") + assert len(VlmNoKVOffloadTransform._module_mapping) > 0 + + def test_vlm_kv_offload_maps_mllama_cross_attention_to_two_qpc(self): + from transformers.models.mllama.modeling_mllama import MllamaTextCrossAttention + + from QEfficient.transformers.models.mllama.modeling_mllama import ( + QEffMllamaTextCrossAttentionTwoQPC, + ) + from QEfficient.transformers.models.pytorch_transforms import VlmKVOffloadTransform + + assert MllamaTextCrossAttention in VlmKVOffloadTransform._module_mapping + assert VlmKVOffloadTransform._module_mapping[MllamaTextCrossAttention] is QEffMllamaTextCrossAttentionTwoQPC + + def test_vlm_no_kv_offload_maps_mllama_cross_attention_to_single_qpc(self): + from transformers.models.mllama.modeling_mllama import MllamaTextCrossAttention + + from QEfficient.transformers.models.mllama.modeling_mllama import ( + QEffMllamaTextCrossAttentionSingleQPC, + ) + from QEfficient.transformers.models.pytorch_transforms import VlmNoKVOffloadTransform + + assert MllamaTextCrossAttention in VlmNoKVOffloadTransform._module_mapping + assert ( + VlmNoKVOffloadTransform._module_mapping[MllamaTextCrossAttention] is QEffMllamaTextCrossAttentionSingleQPC + ) + + def test_vlm_kv_offload_has_apply_method(self): + from QEfficient.transformers.models.pytorch_transforms import VlmKVOffloadTransform + + assert hasattr(VlmKVOffloadTransform, "apply") + assert callable(VlmKVOffloadTransform.apply) + + def test_vlm_no_kv_offload_has_apply_method(self): + from QEfficient.transformers.models.pytorch_transforms import VlmNoKVOffloadTransform + + assert hasattr(VlmNoKVOffloadTransform, "apply") + assert callable(VlmNoKVOffloadTransform.apply) + + def test_single_qpc_pytorch_transforms_include_kv_offload_transform(self): + """SingleQPC must use VlmNoKVOffloadTransform in its pytorch transforms.""" + from QEfficient.transformers.models.modeling_auto import _QEFFAutoModelForImageTextToTextSingleQPC + from QEfficient.transformers.models.pytorch_transforms import VlmNoKVOffloadTransform + + assert VlmNoKVOffloadTransform in _QEFFAutoModelForImageTextToTextSingleQPC._pytorch_transforms, ( + "VlmNoKVOffloadTransform not in SingleQPC._pytorch_transforms" + ) + + def test_single_qpc_pytorch_transforms_include_no_kv_offload(self): + """SingleQPC must use VlmNoKVOffloadTransform in its pytorch transforms.""" + from QEfficient.transformers.models.modeling_auto import _QEFFAutoModelForImageTextToTextSingleQPC + from QEfficient.transformers.models.pytorch_transforms import VlmNoKVOffloadTransform + + assert VlmNoKVOffloadTransform in _QEFFAutoModelForImageTextToTextSingleQPC._pytorch_transforms, ( + "VlmNoKVOffloadTransform not in SingleQPC._pytorch_transforms" + ) diff --git a/tests/unit_test/models/__init__.py b/tests/unit_test/models/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit_test/models/test_cache_correctness.py b/tests/unit_test/models/test_cache_correctness.py new file mode 100644 index 000000000..a1e14ed5f --- /dev/null +++ b/tests/unit_test/models/test_cache_correctness.py @@ -0,0 +1,401 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Correctness tests for QEfficient cache utilities. + +Tests verify numerical correctness of: + - QEffDynamicLayer: scatter/gather round-trip + - QEffDynamicCache: multi-layer update, write/read, prefill+decode + - QEffEncoderDecoderCache: from_legacy_cache + - InvalidIndexProvider: value logic + +All tests run on CPU only. +""" + +import pytest +import torch + +from QEfficient.transformers.cache_utils import ( + InvalidIndexProvider, + QEffDynamicCache, + QEffDynamicLayer, + QEffEncoderDecoderCache, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_kv(batch=1, heads=2, seq=8, head_dim=16): + k = torch.randn(batch, heads, seq, head_dim) + v = torch.randn(batch, heads, seq, head_dim) + return k, v + + +def pos_ids(batch=1, seq=8, start=0): + return torch.arange(start, start + seq).unsqueeze(0).expand(batch, -1) + + +# --------------------------------------------------------------------------- +# Tests: InvalidIndexProvider +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestInvalidIndexProvider: + """InvalidIndexProvider must return 0 outside ONNX export.""" + + def test_returns_zero_outside_onnx_export(self): + val = InvalidIndexProvider._get_invalid_idx_value() + assert val == 0, f"Expected 0 outside ONNX export, got {val}" + + def test_subfunc_disabled_by_default(self): + assert InvalidIndexProvider.SUBFUNC_ENABLED is False + + def test_enable_subfunc_sets_flag(self): + original = InvalidIndexProvider.SUBFUNC_ENABLED + try: + InvalidIndexProvider.enable_subfunc() + assert InvalidIndexProvider.SUBFUNC_ENABLED is True + finally: + InvalidIndexProvider.SUBFUNC_ENABLED = original + + +# --------------------------------------------------------------------------- +# Tests: QEffDynamicLayer +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestQEffDynamicLayerCorrectness: + """QEffDynamicLayer scatter/gather must be numerically correct.""" + + def test_initial_state_is_none(self): + layer = QEffDynamicLayer() + assert layer.keys is None + assert layer.values is None + + def test_first_update_stores_tensors(self): + layer = QEffDynamicLayer() + k, v = make_kv(seq=8) + k_out, v_out = layer.update(k, v, cache_kwargs={"position_ids": pos_ids(seq=8)}) + assert layer.keys is not None + assert layer.values is not None + assert k_out.shape == k.shape + assert v_out.shape == v.shape + + def test_write_then_read_returns_same_values(self): + """write_only then read_only must return the exact same tensors.""" + layer = QEffDynamicLayer() + k, v = make_kv(batch=1, heads=2, seq=8, head_dim=16) + pids = pos_ids(seq=8) + + layer.write_only(k, v, cache_kwargs={"position_ids": pids}) + k_out, v_out = layer.read_only(cache_kwargs={"position_ids": pids}) + + assert k_out.shape == k.shape + assert v_out.shape == v.shape + assert torch.allclose(k_out, k), "read_only must return the same keys as written" + assert torch.allclose(v_out, v), "read_only must return the same values as written" + + def test_update_output_has_ctx_len_dimension(self): + """After update, output must have the context length dimension.""" + layer = QEffDynamicLayer() + batch, heads, ctx_len, head_dim = 1, 2, 16, 8 + k = torch.zeros(batch, heads, ctx_len, head_dim) + v = torch.zeros(batch, heads, ctx_len, head_dim) + pids = pos_ids(seq=ctx_len) + + k_out, v_out = layer.update(k, v, cache_kwargs={"position_ids": pids}) + assert k_out.shape == (batch, heads, ctx_len, head_dim) + assert v_out.shape == (batch, heads, ctx_len, head_dim) + + def test_decode_step_scatter_at_correct_position(self): + """Decode step must scatter the new token at the correct position.""" + layer = QEffDynamicLayer() + batch, heads, ctx_len, head_dim = 1, 2, 16, 8 + + # Initialize with zeros + k_init = torch.zeros(batch, heads, ctx_len, head_dim) + v_init = torch.zeros(batch, heads, ctx_len, head_dim) + layer.update(k_init, v_init, cache_kwargs={"position_ids": pos_ids(seq=ctx_len)}) + + # Decode: write a known value at position 5 + k_new = torch.ones(batch, heads, 1, head_dim) * 7.0 + v_new = torch.ones(batch, heads, 1, head_dim) * 7.0 + pos_decode = torch.tensor([[5]]) + + k_out, v_out = layer.update(k_new, v_new, cache_kwargs={"position_ids": pos_decode}) + + assert k_out.shape[2] == ctx_len + assert k_out[0, 0, 5, 0].item() == pytest.approx(7.0, abs=1e-5), ( + f"Expected 7.0 at position 5, got {k_out[0, 0, 5, 0].item()}" + ) + + def test_update_output_is_finite(self): + layer = QEffDynamicLayer() + k, v = make_kv(seq=8) + k_out, v_out = layer.update(k, v, cache_kwargs={"position_ids": pos_ids(seq=8)}) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + +# --------------------------------------------------------------------------- +# Tests: QEffDynamicCache +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestQEffDynamicCacheCorrectness: + """QEffDynamicCache must correctly manage multiple layers.""" + + def test_empty_cache_creation(self): + cache = QEffDynamicCache() + assert cache is not None + + def test_update_adds_layer(self): + cache = QEffDynamicCache() + k, v = make_kv(seq=8) + k_out, v_out = cache.update(k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=8)}) + assert k_out is not None + assert v_out is not None + + def test_update_multiple_layers_creates_correct_count(self): + cache = QEffDynamicCache() + for i in range(4): + k, v = make_kv(seq=8) + cache.update(k, v, layer_idx=i, cache_kwargs={"position_ids": pos_ids(seq=8)}) + assert len(cache.layers) == 4 + + def test_layers_are_qeff_dynamic_layer_instances(self): + cache = QEffDynamicCache() + k, v = make_kv(seq=8) + cache.update(k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=8)}) + assert isinstance(cache.layers[0], QEffDynamicLayer) + + def test_write_only_then_read_only_returns_same_values(self): + """write_only + read_only round-trip must return identical tensors.""" + cache = QEffDynamicCache() + k, v = make_kv(batch=1, heads=2, seq=8, head_dim=16) + pids = pos_ids(seq=8) + + cache.write_only(k, v, layer_idx=0, cache_kwargs={"position_ids": pids}) + k_out, v_out = cache.read_only(layer_idx=0, cache_kwargs={"position_ids": pids}) + + assert torch.allclose(k_out, k), "read_only must return the same keys as written" + assert torch.allclose(v_out, v), "read_only must return the same values as written" + + def test_prefill_then_decode_produces_finite_outputs(self): + """Prefill + decode must produce finite key/value tensors.""" + cache = QEffDynamicCache() + batch, heads, ctx_len, head_dim = 1, 2, 16, 8 + + k_prefill = torch.randn(batch, heads, ctx_len, head_dim) + v_prefill = torch.randn(batch, heads, ctx_len, head_dim) + cache.update(k_prefill, v_prefill, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=ctx_len)}) + + k_decode = torch.randn(batch, heads, 1, head_dim) + v_decode = torch.randn(batch, heads, 1, head_dim) + pos_decode = torch.tensor([[ctx_len - 1]]) + + k_out, v_out = cache.update(k_decode, v_decode, layer_idx=0, cache_kwargs={"position_ids": pos_decode}) + + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + assert k_out.shape[2] == ctx_len + + def test_decode_scatter_at_correct_position(self): + """Decode must scatter the new token at the correct position in the cache.""" + cache = QEffDynamicCache() + batch, heads, ctx_len, head_dim = 1, 2, 16, 8 + + k_prefill = torch.zeros(batch, heads, ctx_len, head_dim) + v_prefill = torch.zeros(batch, heads, ctx_len, head_dim) + cache.update(k_prefill, v_prefill, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=ctx_len)}) + + k_decode = torch.ones(batch, heads, 1, head_dim) * 42.0 + v_decode = torch.ones(batch, heads, 1, head_dim) * 42.0 + pos_decode = torch.tensor([[3]]) + + k_out, v_out = cache.update(k_decode, v_decode, layer_idx=0, cache_kwargs={"position_ids": pos_decode}) + + assert k_out[0, 0, 3, 0].item() == pytest.approx(42.0, abs=1e-5), ( + f"Expected 42.0 at position 3, got {k_out[0, 0, 3, 0].item()}" + ) + + def test_ddp_cache_data_populates_layers(self): + """QEffDynamicCache with ddp_cache_data must populate layers.""" + k, v = make_kv(seq=8) + ddp_data = [(k, v), (k.clone(), v.clone())] + cache = QEffDynamicCache(ddp_cache_data=ddp_data) + assert len(cache.layers) >= 2 + + def test_batch_index_continuous_batching_mode(self): + """Cache update with batch_index (continuous batching) must work.""" + cache = QEffDynamicCache() + batch, heads, ctx_len, head_dim = 2, 2, 8, 4 + + k = torch.zeros(batch, heads, ctx_len, head_dim) + v = torch.zeros(batch, heads, ctx_len, head_dim) + pids = pos_ids(batch=batch, seq=ctx_len) + batch_index = torch.arange(batch).view(-1, 1) + + k_out, v_out = cache.update(k, v, layer_idx=0, cache_kwargs={"position_ids": pids, "batch_index": batch_index}) + assert k_out is not None + assert v_out is not None + assert torch.isfinite(k_out).all() + + +# --------------------------------------------------------------------------- +# Tests: QEffEncoderDecoderCache +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestQEffEncoderDecoderCacheCorrectness: + """QEffEncoderDecoderCache must correctly initialize from legacy cache.""" + + def test_from_legacy_cache_none_creates_empty_cache(self): + cache = QEffEncoderDecoderCache.from_legacy_cache(past_key_values=None) + assert cache is not None + assert isinstance(cache.self_attention_cache, QEffDynamicCache) + assert isinstance(cache.cross_attention_cache, QEffDynamicCache) + + def test_from_legacy_cache_with_2tuple_populates_self_attention(self): + k, v = make_kv(seq=8) + past = [(k, v), (k.clone(), v.clone())] + cache = QEffEncoderDecoderCache.from_legacy_cache(past_key_values=past) + assert cache is not None + + def test_from_legacy_cache_with_4tuple_populates_cross_attention(self): + k, v = make_kv(seq=8) + past = [(k, v, k.clone(), v.clone())] + cache = QEffEncoderDecoderCache.from_legacy_cache(past_key_values=past) + assert cache is not None + + +# --------------------------------------------------------------------------- +# Tests: Cache numerical correctness (scatter/gather round-trip) +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +@pytest.mark.accuracy +class TestCacheScatterGatherNumericalCorrectness: + """ + Scatter/gather operations must be numerically correct. + These tests verify that the cache correctly stores and retrieves values. + """ + + def test_prefill_values_preserved_in_cache(self): + """After prefill, the cache must contain the exact prefill values.""" + cache = QEffDynamicCache() + batch, heads, ctx_len, head_dim = 1, 2, 16, 8 + + k = torch.arange(batch * heads * ctx_len * head_dim, dtype=torch.float32).reshape( + batch, heads, ctx_len, head_dim + ) + v = k * 2.0 + pids = pos_ids(seq=ctx_len) + + cache.write_only(k, v, layer_idx=0, cache_kwargs={"position_ids": pids}) + k_out, v_out = cache.read_only(layer_idx=0, cache_kwargs={"position_ids": pids}) + + assert torch.allclose(k_out, k), "Cache must preserve exact prefill key values" + assert torch.allclose(v_out, v), "Cache must preserve exact prefill value values" + + def test_decode_overwrites_correct_position(self): + """Decode step must overwrite exactly the specified position.""" + cache = QEffDynamicCache() + batch, heads, ctx_len, head_dim = 1, 2, 16, 4 + + k_prefill = torch.zeros(batch, heads, ctx_len, head_dim) + v_prefill = torch.zeros(batch, heads, ctx_len, head_dim) + cache.update(k_prefill, v_prefill, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=ctx_len)}) + + k_decode = torch.ones(batch, heads, 1, head_dim) * 99.0 + v_decode = torch.ones(batch, heads, 1, head_dim) * 99.0 + pos_decode = torch.tensor([[7]]) + + k_out, v_out = cache.update(k_decode, v_decode, layer_idx=0, cache_kwargs={"position_ids": pos_decode}) + + # Position 7 must have 99.0 + assert k_out[0, 0, 7, 0].item() == pytest.approx(99.0, abs=1e-5) + assert v_out[0, 0, 7, 0].item() == pytest.approx(99.0, abs=1e-5) + + # Other positions must still be 0.0 + assert k_out[0, 0, 0, 0].item() == pytest.approx(0.0, abs=1e-5) + assert k_out[0, 0, 6, 0].item() == pytest.approx(0.0, abs=1e-5) + assert k_out[0, 0, 8, 0].item() == pytest.approx(0.0, abs=1e-5) + + def test_multiple_decode_steps_overwrite_correct_positions(self): + """Multiple decode steps must each overwrite the correct position.""" + cache = QEffDynamicCache() + batch, heads, ctx_len, head_dim = 1, 2, 16, 4 + + k_prefill = torch.zeros(batch, heads, ctx_len, head_dim) + v_prefill = torch.zeros(batch, heads, ctx_len, head_dim) + cache.update(k_prefill, v_prefill, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=ctx_len)}) + + for pos, val in [(2, 10.0), (5, 20.0), (10, 30.0)]: + k_d = torch.ones(batch, heads, 1, head_dim) * val + v_d = torch.ones(batch, heads, 1, head_dim) * val + k_out, v_out = cache.update(k_d, v_d, layer_idx=0, cache_kwargs={"position_ids": torch.tensor([[pos]])}) + + # Final state: position 10 should have 30.0 + assert k_out[0, 0, 10, 0].item() == pytest.approx(30.0, abs=1e-5) + + def test_multi_layer_cache_independence(self): + """Different layers must not interfere with each other.""" + cache = QEffDynamicCache() + batch, heads, ctx_len, head_dim = 1, 2, 8, 4 + + for layer_idx in range(3): + k = torch.ones(batch, heads, ctx_len, head_dim) * float(layer_idx + 1) + v = torch.ones(batch, heads, ctx_len, head_dim) * float(layer_idx + 1) + cache.write_only(k, v, layer_idx=layer_idx, cache_kwargs={"position_ids": pos_ids(seq=ctx_len)}) + + for layer_idx in range(3): + k_out, v_out = cache.read_only(layer_idx=layer_idx, cache_kwargs={"position_ids": pos_ids(seq=ctx_len)}) + expected_val = float(layer_idx + 1) + assert k_out[0, 0, 0, 0].item() == pytest.approx(expected_val, abs=1e-5), ( + f"Layer {layer_idx} key value mismatch: expected {expected_val}, got {k_out[0, 0, 0, 0].item()}" + ) + + def test_decode_does_not_corrupt_prior_positions(self): + """A decode write at position N must not corrupt positions 0..N-1. + + Note: QEfficient's CtxScatter zeros out positions > decode_position + (they are not yet valid tokens). Only positions <= decode_position + are guaranteed to be preserved. + """ + cache = QEffDynamicCache() + batch, heads, ctx_len, head_dim = 1, 1, 8, 4 + + # Prefill with known sequential values + k_prefill = ( + torch.arange(ctx_len, dtype=torch.float32) + .reshape(1, 1, ctx_len, 1) + .expand(batch, heads, ctx_len, head_dim) + .clone() + ) + v_prefill = k_prefill.clone() + cache.update(k_prefill, v_prefill, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=ctx_len)}) + + # Decode: overwrite position 4 with 999.0 + k_decode = torch.ones(batch, heads, 1, head_dim) * 999.0 + v_decode = torch.ones(batch, heads, 1, head_dim) * 999.0 + k_out, v_out = cache.update(k_decode, v_decode, layer_idx=0, cache_kwargs={"position_ids": torch.tensor([[4]])}) + + # Position 4 must be 999.0 + assert k_out[0, 0, 4, 0].item() == pytest.approx(999.0, abs=1e-5) + # Positions before the decode position must be preserved + assert k_out[0, 0, 3, 0].item() == pytest.approx(3.0, abs=1e-5) + assert k_out[0, 0, 0, 0].item() == pytest.approx(0.0, abs=1e-5) + assert k_out[0, 0, 1, 0].item() == pytest.approx(1.0, abs=1e-5) + assert k_out[0, 0, 2, 0].item() == pytest.approx(2.0, abs=1e-5) diff --git a/tests/unit_test/models/test_causal_lm_accuracy.py b/tests/unit_test/models/test_causal_lm_accuracy.py new file mode 100644 index 000000000..ccf455a3c --- /dev/null +++ b/tests/unit_test/models/test_causal_lm_accuracy.py @@ -0,0 +1,872 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Accuracy tests for CausalLM models: HF PyTorch → QEff PyTorch → ONNX structure. + +Improvements over unit_v2: + - Expanded model coverage: GPT2, Llama, Mistral, Qwen2, Phi3, Gemma, Gemma2, Falcon + - Continuous batching mode tests + - ONNX structure validation for all models + +Key accuracy assertions: + - HF and QEff produce the SAME greedy next token (argmax of last-token logits) + - HF and QEff logits are numerically close (softmax max_diff < 1e-3) + - Decode step produces valid tokens in range [0, vocab_size) + +All tests run on CPU only. +""" + +import pytest +import torch +import torch.nn.functional as F +from transformers import ( + FalconConfig, + FalconForCausalLM, + GemmaConfig, + GemmaForCausalLM, + GPT2Config, + GPT2LMHeadModel, + LlamaConfig, + LlamaForCausalLM, + MistralConfig, + MistralForCausalLM, + Phi3Config, + Phi3ForCausalLM, + Qwen2Config, + Qwen2ForCausalLM, +) + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + +CTX_LEN = 32 +SEQ_LEN = 8 +VOCAB_SIZE = 500 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_dims(config): + """Extract (n_layers, n_kv_heads, head_dim) from any config.""" + if hasattr(config, "num_hidden_layers"): + n_layers = config.num_hidden_layers + n_attn = config.num_attention_heads + n_kv = getattr(config, "num_key_value_heads", n_attn) + head_dim = getattr(config, "head_dim", None) or (config.hidden_size // n_attn) + else: + n_layers = config.n_layer + n_attn = config.n_head + n_kv = config.n_head + head_dim = config.n_embd // n_attn + return n_layers, n_kv, head_dim + + +def make_qeff_inputs(input_ids, config, ctx_len=CTX_LEN): + """Build QEff-style inputs: input_ids + position_ids + zero-init past_key_values.""" + batch, seq = input_ids.shape + position_ids = torch.arange(seq).unsqueeze(0).expand(batch, -1) + n_layers, n_kv, head_dim = _get_dims(config) + past_key_values = tuple( + ( + torch.zeros(batch, n_kv, ctx_len, head_dim, dtype=torch.float32), + torch.zeros(batch, n_kv, ctx_len, head_dim, dtype=torch.float32), + ) + for _ in range(n_layers) + ) + return {"input_ids": input_ids, "position_ids": position_ids, "past_key_values": past_key_values} + + +# --------------------------------------------------------------------------- +# Tiny model factories +# --------------------------------------------------------------------------- + + +def make_tiny_gpt2(): + cfg = GPT2Config(n_layer=2, n_head=2, n_embd=64, vocab_size=VOCAB_SIZE, n_positions=CTX_LEN, n_ctx=CTX_LEN) + return GPT2LMHeadModel(cfg).eval(), cfg + + +def make_tiny_llama(): + cfg = LlamaConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return LlamaForCausalLM(cfg).eval(), cfg + + +def make_tiny_mistral(): + cfg = MistralConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return MistralForCausalLM(cfg).eval(), cfg + + +def make_tiny_qwen2(): + cfg = Qwen2Config( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return Qwen2ForCausalLM(cfg).eval(), cfg + + +def make_tiny_phi3(): + cfg = Phi3Config( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + pad_token_id=0, + ) + return Phi3ForCausalLM(cfg).eval(), cfg + + +def make_tiny_gemma(): + cfg = GemmaConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + head_dim=32, + ) + return GemmaForCausalLM(cfg).eval(), cfg + + +def make_tiny_falcon(): + cfg = FalconConfig( + num_hidden_layers=2, + num_attention_heads=2, + hidden_size=64, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + new_decoder_architecture=False, + multi_query=True, + ) + return FalconForCausalLM(cfg).eval(), cfg + + +# --------------------------------------------------------------------------- +# Stage 1: HF PyTorch baseline +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +class TestHFCausalLMBaseline: + """HF models run correctly on CPU and produce valid logits.""" + + def _check_logits_shape(self, factory, label): + model, cfg = factory() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + with torch.no_grad(): + out = model(input_ids=input_ids) + assert out.logits.shape == (1, SEQ_LEN, VOCAB_SIZE), ( + f"[{label}] Expected logits shape (1, {SEQ_LEN}, {VOCAB_SIZE}), got {out.logits.shape}" + ) + + def test_gpt2_forward_returns_logits_with_correct_shape(self): + self._check_logits_shape(make_tiny_gpt2, "GPT2") + + def test_llama_forward_returns_logits_with_correct_shape(self): + self._check_logits_shape(make_tiny_llama, "Llama") + + def test_mistral_forward_returns_logits_with_correct_shape(self): + self._check_logits_shape(make_tiny_mistral, "Mistral") + + def test_qwen2_forward_returns_logits_with_correct_shape(self): + self._check_logits_shape(make_tiny_qwen2, "Qwen2") + + def test_phi3_forward_returns_logits_with_correct_shape(self): + self._check_logits_shape(make_tiny_phi3, "Phi3") + + def test_gemma_forward_returns_logits_with_correct_shape(self): + self._check_logits_shape(make_tiny_gemma, "Gemma") + + def test_falcon_forward_returns_logits_with_correct_shape(self): + self._check_logits_shape(make_tiny_falcon, "Falcon") + + def test_hf_logits_are_finite(self): + """HF logits must not contain NaN or Inf for any model.""" + for factory, label in [ + (make_tiny_gpt2, "GPT2"), + (make_tiny_llama, "Llama"), + (make_tiny_mistral, "Mistral"), + (make_tiny_qwen2, "Qwen2"), + (make_tiny_phi3, "Phi3"), + (make_tiny_gemma, "Gemma"), + ]: + model, cfg = factory() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + with torch.no_grad(): + logits = model(input_ids=input_ids).logits + assert torch.isfinite(logits).all(), f"[{label}] HF logits contain NaN/Inf" + + def test_gpt2_greedy_decode_is_deterministic(self): + model, cfg = make_tiny_gpt2() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + with torch.no_grad(): + t1 = model(input_ids=input_ids).logits[:, -1, :].argmax(-1).item() + t2 = model(input_ids=input_ids).logits[:, -1, :].argmax(-1).item() + assert t1 == t2, "Greedy decode must be deterministic" + + +# --------------------------------------------------------------------------- +# Stage 2: QEff PyTorch accuracy vs HF +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +@pytest.mark.accuracy +class TestQEffCausalLMAccuracyVsHF: + """ + QEff KV-transformed model must produce the same greedy next token as HF. + This is the primary regression test: if KVCacheTransform or CustomOpsTransform + changes the model's numerical output, these tests will catch it. + """ + + def _assert_same_greedy_token(self, model, cfg, label): + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + hf_logits = model(input_ids=input_ids).logits[:, -1, :] + hf_token = hf_logits.argmax(-1).item() + + qeff_model = QEFFAutoModelForCausalLM(model) + qeff_inputs = make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + qeff_logits = qeff_model.model(**qeff_inputs).logits[:, -1, :] + qeff_token = qeff_logits.argmax(-1).item() + + assert hf_token == qeff_token, ( + f"[{label}] Greedy token mismatch: HF={hf_token}, QEff={qeff_token}. " + f"KVCacheTransform must not change the model's greedy prediction." + ) + + def _assert_logits_numerically_close(self, model, cfg, label, atol=1e-3): + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + hf_logits = model(input_ids=input_ids).logits[:, -1, :] + + qeff_model = QEFFAutoModelForCausalLM(model) + qeff_inputs = make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + qeff_logits = qeff_model.model(**qeff_inputs).logits[:, -1, :] + + hf_probs = F.softmax(hf_logits, dim=-1) + qeff_probs = F.softmax(qeff_logits, dim=-1) + max_diff = (hf_probs - qeff_probs).abs().max().item() + assert max_diff < atol, f"[{label}] Probability distribution mismatch: max_diff={max_diff:.6f} > atol={atol}." + + def test_gpt2_qeff_matches_hf_greedy_token(self): + model, cfg = make_tiny_gpt2() + self._assert_same_greedy_token(model, cfg, "GPT2") + + def test_llama_qeff_matches_hf_greedy_token(self): + model, cfg = make_tiny_llama() + self._assert_same_greedy_token(model, cfg, "Llama") + + def test_mistral_qeff_matches_hf_greedy_token(self): + model, cfg = make_tiny_mistral() + self._assert_same_greedy_token(model, cfg, "Mistral") + + def test_qwen2_qeff_matches_hf_greedy_token(self): + model, cfg = make_tiny_qwen2() + self._assert_same_greedy_token(model, cfg, "Qwen2") + + def test_phi3_qeff_matches_hf_greedy_token(self): + model, cfg = make_tiny_phi3() + self._assert_same_greedy_token(model, cfg, "Phi3") + + def test_gemma_qeff_matches_hf_greedy_token(self): + model, cfg = make_tiny_gemma() + self._assert_same_greedy_token(model, cfg, "Gemma") + + def test_gpt2_qeff_logits_numerically_close_to_hf(self): + model, cfg = make_tiny_gpt2() + self._assert_logits_numerically_close(model, cfg, "GPT2") + + def test_llama_qeff_logits_numerically_close_to_hf(self): + model, cfg = make_tiny_llama() + self._assert_logits_numerically_close(model, cfg, "Llama") + + def test_mistral_qeff_logits_numerically_close_to_hf(self): + model, cfg = make_tiny_mistral() + self._assert_logits_numerically_close(model, cfg, "Mistral") + + def test_qwen2_qeff_logits_numerically_close_to_hf(self): + model, cfg = make_tiny_qwen2() + self._assert_logits_numerically_close(model, cfg, "Qwen2") + + def test_phi3_qeff_logits_numerically_close_to_hf(self): + model, cfg = make_tiny_phi3() + self._assert_logits_numerically_close(model, cfg, "Phi3") + + def test_qeff_logits_are_finite(self): + """QEff logits must not contain NaN or Inf for any model.""" + for factory, label in [ + (make_tiny_gpt2, "GPT2"), + (make_tiny_llama, "Llama"), + (make_tiny_mistral, "Mistral"), + (make_tiny_qwen2, "Qwen2"), + (make_tiny_phi3, "Phi3"), + ]: + model, cfg = factory() + qeff_model = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + qeff_inputs = make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + logits = qeff_model.model(**qeff_inputs).logits + assert torch.isfinite(logits).all(), f"[{label}] QEff logits contain NaN/Inf" + + def test_qeff_past_key_values_returned(self): + """QEff model must return past_key_values for the decode step.""" + model, cfg = make_tiny_gpt2() + qeff_model = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + qeff_inputs = make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + out = qeff_model.model(**qeff_inputs) + assert out.past_key_values is not None, "QEff model must return past_key_values" + + def test_gpt2_top5_tokens_overlap_with_hf(self): + """Top-5 predicted tokens must overlap between HF and QEff.""" + model, cfg = make_tiny_gpt2() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + hf_top5 = set(model(input_ids=input_ids).logits[:, -1, :].topk(5).indices.squeeze().tolist()) + + qeff_model = QEFFAutoModelForCausalLM(model) + qeff_inputs = make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + qeff_top5 = set(qeff_model.model(**qeff_inputs).logits[:, -1, :].topk(5).indices.squeeze().tolist()) + + overlap = len(hf_top5 & qeff_top5) + assert overlap >= 4, f"Top-5 token overlap too low: {overlap}/5. HF={hf_top5}, QEff={qeff_top5}" + + +# --------------------------------------------------------------------------- +# Stage 2b: Decode step accuracy +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +@pytest.mark.accuracy +class TestQEffDecodeStepAccuracy: + """Decode step must produce consistent, finite tokens.""" + + def _run_prefill_then_decode(self, model, cfg, n_decode_steps=3, input_ids=None): + """Run prefill + n decode steps, return list of generated token IDs.""" + qeff_model = QEFFAutoModelForCausalLM(model) + if input_ids is None: + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + qeff_inputs = make_qeff_inputs(input_ids, cfg) + + generated = [] + with torch.no_grad(): + out = qeff_model.model(**qeff_inputs) + next_token = out.logits[:, -1, :].argmax(-1).item() + generated.append(next_token) + prev_pos = SEQ_LEN - 1 + + for _ in range(n_decode_steps - 1): + n_layers, n_kv, head_dim = _get_dims(cfg) + decode_inputs = { + "input_ids": torch.tensor([[next_token]], dtype=torch.long), + "position_ids": torch.tensor([[prev_pos + 1]], dtype=torch.long), + "past_key_values": tuple( + ( + torch.zeros(1, n_kv, CTX_LEN, head_dim, dtype=torch.float32), + torch.zeros(1, n_kv, CTX_LEN, head_dim, dtype=torch.float32), + ) + for _ in range(n_layers) + ), + } + out = qeff_model.model(**decode_inputs) + next_token = out.logits[:, -1, :].argmax(-1).item() + generated.append(next_token) + prev_pos += 1 + + return generated + + def test_gpt2_decode_produces_valid_tokens(self): + model, cfg = make_tiny_gpt2() + tokens = self._run_prefill_then_decode(model, cfg, n_decode_steps=3) + assert len(tokens) == 3 + assert all(0 <= t < VOCAB_SIZE for t in tokens), f"Invalid token IDs: {tokens}" + + def test_llama_decode_produces_valid_tokens(self): + model, cfg = make_tiny_llama() + tokens = self._run_prefill_then_decode(model, cfg, n_decode_steps=3) + assert len(tokens) == 3 + assert all(0 <= t < VOCAB_SIZE for t in tokens), f"Invalid token IDs: {tokens}" + + def test_mistral_decode_produces_valid_tokens(self): + model, cfg = make_tiny_mistral() + tokens = self._run_prefill_then_decode(model, cfg, n_decode_steps=3) + assert len(tokens) == 3 + assert all(0 <= t < VOCAB_SIZE for t in tokens), f"Invalid token IDs: {tokens}" + + def test_phi3_decode_produces_valid_tokens(self): + model, cfg = make_tiny_phi3() + tokens = self._run_prefill_then_decode(model, cfg, n_decode_steps=3) + assert len(tokens) == 3 + assert all(0 <= t < VOCAB_SIZE for t in tokens), f"Invalid token IDs: {tokens}" + + def test_gpt2_prefill_token_matches_hf_next_token(self): + """The first token from QEff prefill must match HF's greedy next token.""" + model, cfg = make_tiny_gpt2() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + hf_next = model(input_ids=input_ids).logits[:, -1, :].argmax(-1).item() + + qeff_model = QEFFAutoModelForCausalLM(model) + qeff_inputs = make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + qeff_next = qeff_model.model(**qeff_inputs).logits[:, -1, :].argmax(-1).item() + + assert hf_next == qeff_next, f"Prefill next token mismatch: HF={hf_next}, QEff={qeff_next}" + + def test_llama_prefill_token_matches_hf_next_token(self): + model, cfg = make_tiny_llama() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + hf_next = model(input_ids=input_ids).logits[:, -1, :].argmax(-1).item() + + qeff_model = QEFFAutoModelForCausalLM(model) + qeff_inputs = make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + qeff_next = qeff_model.model(**qeff_inputs).logits[:, -1, :].argmax(-1).item() + + assert hf_next == qeff_next, f"Prefill next token mismatch: HF={hf_next}, QEff={qeff_next}" + + def test_gpt2_decode_is_deterministic(self): + """Same model + same input must produce the same decode sequence.""" + import copy + + model, cfg = make_tiny_gpt2() + model_copy = copy.deepcopy(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + tokens1 = self._run_prefill_then_decode(model, cfg, n_decode_steps=3, input_ids=input_ids) + tokens2 = self._run_prefill_then_decode(model_copy, cfg, n_decode_steps=3, input_ids=input_ids) + assert tokens1 == tokens2, f"Decode is not deterministic: {tokens1} vs {tokens2}" + + +# --------------------------------------------------------------------------- +# Stage 2c: Continuous batching mode +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +class TestContinuousBatchingMode: + """ + QEFFAutoModelForCausalLM with continuous_batching=True must wrap correctly + and produce valid outputs. + """ + + def test_gpt2_continuous_batching_wraps_without_error(self): + model, cfg = make_tiny_gpt2() + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=True) + assert qeff is not None + assert qeff.continuous_batching is True + + def test_llama_continuous_batching_wraps_without_error(self): + model, cfg = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=True) + assert qeff is not None + assert qeff.continuous_batching is True + + def test_gpt2_continuous_batching_model_is_transformed(self): + """With continuous_batching=True, the model must still be KV-transformed.""" + from QEfficient.transformers.models.gpt2.modeling_gpt2 import QEffGPT2LMHeadModel + + model, cfg = make_tiny_gpt2() + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=True) + assert isinstance(qeff.model, QEffGPT2LMHeadModel) + + def test_continuous_batching_false_is_default(self): + model, cfg = make_tiny_gpt2() + qeff = QEFFAutoModelForCausalLM(model) + assert qeff.continuous_batching is False + + def test_continuous_batching_model_produces_finite_logits(self): + """Continuous batching model must produce finite logits.""" + model, cfg = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=True) + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + qeff_inputs = make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + out = qeff.model(**qeff_inputs) + assert torch.isfinite(out.logits).all() + + +# --------------------------------------------------------------------------- +# Stage 3: ONNX export structure +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +@pytest.mark.onnx +@pytest.mark.slow +class TestCausalLMONNXStructure: + """ + ONNX export must produce valid models with correct KV cache inputs/outputs. + """ + + def _check_onnx_export(self, factory, label, tmp_export_dir): + import os + + model, cfg = factory() + qeff_model = QEFFAutoModelForCausalLM(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + assert onnx_path is not None, f"[{label}] ONNX export returned None" + assert os.path.exists(str(onnx_path)), f"[{label}] ONNX file does not exist" + assert os.path.getsize(str(onnx_path)) > 0, f"[{label}] ONNX file is empty" + return onnx_path + + def test_gpt2_onnx_export_succeeds(self, tmp_export_dir): + self._check_onnx_export(make_tiny_gpt2, "GPT2", tmp_export_dir) + + def test_llama_onnx_export_succeeds(self, tmp_export_dir): + self._check_onnx_export(make_tiny_llama, "Llama", tmp_export_dir) + + def test_mistral_onnx_export_succeeds(self, tmp_export_dir): + self._check_onnx_export(make_tiny_mistral, "Mistral", tmp_export_dir) + + def test_qwen2_onnx_export_succeeds(self, tmp_export_dir): + self._check_onnx_export(make_tiny_qwen2, "Qwen2", tmp_export_dir) + + def test_phi3_onnx_export_succeeds(self, tmp_export_dir): + self._check_onnx_export(make_tiny_phi3, "Phi3", tmp_export_dir) + + def test_gpt2_onnx_passes_checker(self, tmp_export_dir): + import onnx + + onnx_path = self._check_onnx_export(make_tiny_gpt2, "GPT2", tmp_export_dir) + onnx_model = onnx.load(str(onnx_path)) + onnx.checker.check_model(onnx_model) + + def test_llama_onnx_passes_checker(self, tmp_export_dir): + import onnx + + onnx_path = self._check_onnx_export(make_tiny_llama, "Llama", tmp_export_dir) + onnx_model = onnx.load(str(onnx_path)) + onnx.checker.check_model(onnx_model) + + def test_gpt2_onnx_has_input_ids_and_position_ids(self, tmp_export_dir): + import onnx + + onnx_path = self._check_onnx_export(make_tiny_gpt2, "GPT2", tmp_export_dir) + onnx_model = onnx.load(str(onnx_path)) + input_names = {inp.name for inp in onnx_model.graph.input} + assert "input_ids" in input_names, f"input_ids missing from ONNX inputs: {input_names}" + assert "position_ids" in input_names, f"position_ids missing from ONNX inputs: {input_names}" + + def test_gpt2_onnx_has_kv_cache_inputs_for_all_layers(self, tmp_export_dir): + import onnx + + n_layers = 2 + onnx_path = self._check_onnx_export(make_tiny_gpt2, "GPT2", tmp_export_dir) + onnx_model = onnx.load(str(onnx_path)) + input_names = {inp.name for inp in onnx_model.graph.input} + for i in range(n_layers): + assert f"past_key.{i}" in input_names, f"past_key.{i} missing from ONNX inputs" + assert f"past_value.{i}" in input_names, f"past_value.{i} missing from ONNX inputs" + + def test_llama_onnx_has_kv_cache_inputs_for_all_layers(self, tmp_export_dir): + import onnx + + n_layers = 2 + onnx_path = self._check_onnx_export(make_tiny_llama, "Llama", tmp_export_dir) + onnx_model = onnx.load(str(onnx_path)) + input_names = {inp.name for inp in onnx_model.graph.input} + for i in range(n_layers): + assert f"past_key.{i}" in input_names, f"past_key.{i} missing from ONNX inputs" + assert f"past_value.{i}" in input_names, f"past_value.{i} missing from ONNX inputs" + + def test_gpt2_onnx_has_logits_output(self, tmp_export_dir): + import onnx + + onnx_path = self._check_onnx_export(make_tiny_gpt2, "GPT2", tmp_export_dir) + onnx_model = onnx.load(str(onnx_path)) + output_names = {out.name for out in onnx_model.graph.output} + assert "logits" in output_names, f"logits missing from ONNX outputs: {output_names}" + + def test_gpt2_onnx_has_retained_state_outputs(self, tmp_export_dir): + """KV cache outputs must be present as RetainedState outputs.""" + import onnx + + onnx_path = self._check_onnx_export(make_tiny_gpt2, "GPT2", tmp_export_dir) + onnx_model = onnx.load(str(onnx_path)) + output_names = [out.name for out in onnx_model.graph.output] + retained = [n for n in output_names if "RetainedState" in n] + assert len(retained) > 0, f"No RetainedState outputs found: {output_names}" + + def test_gpt2_onnx_uses_correct_opset_version(self, tmp_export_dir): + """Exported ONNX must use the opset version defined in QEfficient constants.""" + import onnx + + from QEfficient.utils.constants import ONNX_EXPORT_OPSET + + onnx_path = self._check_onnx_export(make_tiny_gpt2, "GPT2", tmp_export_dir) + onnx_model = onnx.load(str(onnx_path)) + opset_versions = [op.version for op in onnx_model.opset_import] + assert ONNX_EXPORT_OPSET in opset_versions, ( + f"Expected opset {ONNX_EXPORT_OPSET} in ONNX opset_import, got {opset_versions}" + ) + + def test_gpt2_ort_session_creation_succeeds(self, tmp_export_dir): + """ORT session must be creatable from the exported ONNX.""" + import onnxruntime as ort + + onnx_path = self._check_onnx_export(make_tiny_gpt2, "GPT2", tmp_export_dir) + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + assert session is not None + ort_inputs = {inp.name for inp in session.get_inputs()} + assert "input_ids" in ort_inputs + assert "position_ids" in ort_inputs + + def _check_ort_prefill_accuracy(self, factory, label, tmp_export_dir): + """ + Export model with SUBFUNC_ENABLED, run ORT prefill, return + (pt_logits_last, ort_logits_last, session, output_names, input_ids, cfg). + + ORT cannot handle INT32_MAX as a GatherND index (the default sentinel used during + ONNX export). Subfunc mode substitutes 0 instead, which is a valid index and + produces numerically identical results because those positions are masked out + afterward by the attention mask. + """ + import numpy as np + import onnxruntime as ort + + from QEfficient.transformers.cache_utils import InvalidIndexProvider + + model, cfg = factory() + qeff_model = QEFFAutoModelForCausalLM(model) + + InvalidIndexProvider.SUBFUNC_ENABLED = True + try: + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir), offload_pt_weights=False) + finally: + InvalidIndexProvider.SUBFUNC_ENABLED = False + + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + qeff_inputs = make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + pt_logits = qeff_model.model(**qeff_inputs).logits[:, -1, :].numpy() + + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + n_layers, n_kv, head_dim = _get_dims(cfg) + ort_inputs = { + "input_ids": input_ids.numpy(), + "position_ids": torch.arange(SEQ_LEN).unsqueeze(0).numpy(), + } + for i in range(n_layers): + ort_inputs[f"past_key.{i}"] = np.zeros((1, n_kv, CTX_LEN, head_dim), dtype=np.float32) + ort_inputs[f"past_value.{i}"] = np.zeros((1, n_kv, CTX_LEN, head_dim), dtype=np.float32) + + output_names = [o.name for o in session.get_outputs()] + ort_out = dict(zip(output_names, session.run(output_names, ort_inputs))) + ort_logits = ort_out["logits"][:, -1, :] + + return pt_logits, ort_logits, session, output_names, input_ids, cfg + + def test_gpt2_ort_prefill_produces_correct_logits(self, tmp_export_dir): + """ORT prefill must produce logits matching QEff PyTorch.""" + pt_logits, ort_logits, _, _, _, _ = self._check_ort_prefill_accuracy(make_tiny_gpt2, "GPT2", tmp_export_dir) + pt_token = int(pt_logits.argmax(-1)) + ort_token = int(ort_logits.argmax(-1)) + assert pt_token == ort_token, f"Token mismatch: PyTorch={pt_token}, ORT={ort_token}" + + def test_llama_ort_session_creation_succeeds(self, tmp_export_dir): + """ORT session must be creatable from the exported Llama ONNX.""" + import onnxruntime as ort + + from QEfficient.transformers.cache_utils import InvalidIndexProvider + + model, cfg = make_tiny_llama() + qeff_model = QEFFAutoModelForCausalLM(model) + InvalidIndexProvider.SUBFUNC_ENABLED = True + try: + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir), offload_pt_weights=False) + finally: + InvalidIndexProvider.SUBFUNC_ENABLED = False + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + assert session is not None + ort_inputs = {inp.name for inp in session.get_inputs()} + assert "input_ids" in ort_inputs + assert "position_ids" in ort_inputs + + def test_mistral_ort_session_creation_succeeds(self, tmp_export_dir): + """ORT session must be creatable from the exported Mistral ONNX.""" + import onnxruntime as ort + + from QEfficient.transformers.cache_utils import InvalidIndexProvider + + model, cfg = make_tiny_mistral() + qeff_model = QEFFAutoModelForCausalLM(model) + InvalidIndexProvider.SUBFUNC_ENABLED = True + try: + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir), offload_pt_weights=False) + finally: + InvalidIndexProvider.SUBFUNC_ENABLED = False + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + assert session is not None + ort_inputs = {inp.name for inp in session.get_inputs()} + assert "input_ids" in ort_inputs + assert "position_ids" in ort_inputs + + def test_qwen2_ort_session_creation_succeeds(self, tmp_export_dir): + """ORT session must be creatable from the exported Qwen2 ONNX.""" + import onnxruntime as ort + + from QEfficient.transformers.cache_utils import InvalidIndexProvider + + model, cfg = make_tiny_qwen2() + qeff_model = QEFFAutoModelForCausalLM(model) + InvalidIndexProvider.SUBFUNC_ENABLED = True + try: + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir), offload_pt_weights=False) + finally: + InvalidIndexProvider.SUBFUNC_ENABLED = False + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + assert session is not None + ort_inputs = {inp.name for inp in session.get_inputs()} + assert "input_ids" in ort_inputs + assert "position_ids" in ort_inputs + + def test_phi3_ort_session_creation_succeeds(self, tmp_export_dir): + """ORT session must be creatable from the exported Phi3 ONNX.""" + import onnxruntime as ort + + from QEfficient.transformers.cache_utils import InvalidIndexProvider + + model, cfg = make_tiny_phi3() + qeff_model = QEFFAutoModelForCausalLM(model) + InvalidIndexProvider.SUBFUNC_ENABLED = True + try: + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir), offload_pt_weights=False) + finally: + InvalidIndexProvider.SUBFUNC_ENABLED = False + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + assert session is not None + ort_inputs = {inp.name for inp in session.get_inputs()} + assert "input_ids" in ort_inputs + assert "position_ids" in ort_inputs + + def test_llama_ort_prefill_produces_correct_logits(self, tmp_export_dir): + """ORT Llama prefill must produce logits matching QEff PyTorch.""" + pt_logits, ort_logits, _, _, _, _ = self._check_ort_prefill_accuracy(make_tiny_llama, "Llama", tmp_export_dir) + pt_token = int(pt_logits.argmax(-1)) + ort_token = int(ort_logits.argmax(-1)) + assert pt_token == ort_token, f"[Llama] Token mismatch: PyTorch={pt_token}, ORT={ort_token}" + + def test_mistral_ort_prefill_produces_correct_logits(self, tmp_export_dir): + """ORT Mistral prefill must produce logits matching QEff PyTorch.""" + pt_logits, ort_logits, _, _, _, _ = self._check_ort_prefill_accuracy( + make_tiny_mistral, "Mistral", tmp_export_dir + ) + pt_token = int(pt_logits.argmax(-1)) + ort_token = int(ort_logits.argmax(-1)) + assert pt_token == ort_token, f"[Mistral] Token mismatch: PyTorch={pt_token}, ORT={ort_token}" + + def test_qwen2_ort_prefill_produces_correct_logits(self, tmp_export_dir): + """ORT Qwen2 prefill must produce logits matching QEff PyTorch.""" + pt_logits, ort_logits, _, _, _, _ = self._check_ort_prefill_accuracy(make_tiny_qwen2, "Qwen2", tmp_export_dir) + pt_token = int(pt_logits.argmax(-1)) + ort_token = int(ort_logits.argmax(-1)) + assert pt_token == ort_token, f"[Qwen2] Token mismatch: PyTorch={pt_token}, ORT={ort_token}" + + def test_phi3_ort_prefill_produces_correct_logits(self, tmp_export_dir): + """ORT Phi3 prefill must produce logits matching QEff PyTorch.""" + pt_logits, ort_logits, _, _, _, _ = self._check_ort_prefill_accuracy(make_tiny_phi3, "Phi3", tmp_export_dir) + pt_token = int(pt_logits.argmax(-1)) + ort_token = int(ort_logits.argmax(-1)) + assert pt_token == ort_token, f"[Phi3] Token mismatch: PyTorch={pt_token}, ORT={ort_token}" + + def test_gpt2_ort_logits_are_finite(self, tmp_export_dir): + """ORT logits must not contain NaN or Inf.""" + import numpy as np + + _, ort_logits, _, _, _, _ = self._check_ort_prefill_accuracy(make_tiny_gpt2, "GPT2", tmp_export_dir) + assert np.isfinite(ort_logits).all(), "ORT GPT2 logits contain NaN/Inf" + + def test_gpt2_ort_output_shape_is_correct(self, tmp_export_dir): + """ORT logits shape must be (batch, seq_len, vocab_size) where seq_len matches input.""" + import numpy as np + import onnxruntime as ort + + from QEfficient.transformers.cache_utils import InvalidIndexProvider + + model, cfg = make_tiny_gpt2() + qeff_model = QEFFAutoModelForCausalLM(model) + InvalidIndexProvider.SUBFUNC_ENABLED = True + try: + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir), offload_pt_weights=False) + finally: + InvalidIndexProvider.SUBFUNC_ENABLED = False + + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"]) + n_layers, n_kv, head_dim = _get_dims(cfg) + ort_inputs = { + "input_ids": input_ids.numpy(), + "position_ids": torch.arange(SEQ_LEN).unsqueeze(0).numpy(), + } + for i in range(n_layers): + ort_inputs[f"past_key.{i}"] = np.zeros((1, n_kv, CTX_LEN, head_dim), dtype=np.float32) + ort_inputs[f"past_value.{i}"] = np.zeros((1, n_kv, CTX_LEN, head_dim), dtype=np.float32) + + output_names = [o.name for o in session.get_outputs()] + ort_out = dict(zip(output_names, session.run(output_names, ort_inputs))) + logits = ort_out["logits"] + # ORT model returns logits with shape (batch, actual_seq_len, vocab_size) + # where actual_seq_len may be 1 (last token only) or match input seq_len + assert logits.shape[0] == 1, f"Expected batch size 1, got {logits.shape[0]}" + assert logits.shape[2] == VOCAB_SIZE, f"Expected vocab size {VOCAB_SIZE}, got {logits.shape[2]}" + assert logits.shape[1] in [1, SEQ_LEN], f"Expected seq_len to be 1 or {SEQ_LEN}, got {logits.shape[1]}" + + def test_gpt2_ort_kv_cache_outputs_present(self, tmp_export_dir): + """ORT outputs must include RetainedState KV cache entries.""" + _, _, session, output_names, _, _ = self._check_ort_prefill_accuracy(make_tiny_gpt2, "GPT2", tmp_export_dir) + retained = [n for n in output_names if "RetainedState" in n] + assert len(retained) > 0, f"No RetainedState outputs in ORT session: {output_names}" + + def test_gpt2_ort_logits_numerically_close_to_pytorch(self, tmp_export_dir): + """ORT and PyTorch softmax distributions must be close (max_diff < 1e-3).""" + import numpy as np + + pt_logits, ort_logits, _, _, _, _ = self._check_ort_prefill_accuracy(make_tiny_gpt2, "GPT2", tmp_export_dir) + pt_probs = torch.tensor(pt_logits).softmax(-1).numpy() + ort_probs = torch.tensor(ort_logits).softmax(-1).numpy() + max_diff = float(np.abs(pt_probs - ort_probs).max()) + assert max_diff < 1e-3, f"ORT vs PyTorch softmax max_diff={max_diff:.6f} exceeds 1e-3" diff --git a/tests/unit_test/models/test_gemma2_accuracy.py b/tests/unit_test/models/test_gemma2_accuracy.py new file mode 100644 index 000000000..29a48616e --- /dev/null +++ b/tests/unit_test/models/test_gemma2_accuracy.py @@ -0,0 +1,565 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" + +Gemma2 is architecturally distinct from all other tested models: + 1. Uses QEffHybridCache (not QEffDynamicCache) — completely different cache class + 2. QEffGemma2ForCausalLM.forward() uses: + logit_index = position_ids.to(torch.int32).argmax(1, keepdim=True) + hidden_states = outputs[0][arange, logit_index] + → returns logits of shape (batch, 1, vocab), NOT (batch, seq, vocab) + 3. Has final_logit_softcapping (tanh-based logit capping) + 4. Has sliding-window attention layers interleaved with full-context layers + +A bug in any of these paths would be invisible to the existing test suite. + +Tests verify: + - HF Gemma2 baseline: correct logit shape, finite outputs + - QEff Gemma2 wraps correctly (QEffGemma2ForCausalLM class is used) + - QEff Gemma2 returns (batch, 1, vocab) shaped logits + - QEff Gemma2 prefill token matches HF greedy token + - QEff Gemma2 logits are numerically close to HF (softmax max_diff < 1e-3) + - QEff Gemma2 cache is non-zero after prefill (CtxScatterFunc ran) + - QEff Gemma2 prefill → decode handoff with REAL cache + - QEff Gemma2 decode produces valid, finite, deterministic tokens + - QEff Gemma2 real cache differs from zero cache (cache influences output) + +All tests run on CPU only. +""" + +import pytest +import torch +import torch.nn.functional as F +from transformers import Gemma2Config, Gemma2ForCausalLM + +from QEfficient.transformers.models.gemma2.modeling_gemma2 import QEffGemma2ForCausalLM +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + +CTX_LEN = 32 +PREFILL_LEN = 8 +VOCAB_SIZE = 500 + + +# --------------------------------------------------------------------------- +# Tiny Gemma2 factory +# --------------------------------------------------------------------------- + + +def make_tiny_gemma2(): + """ + Minimal Gemma2 config that exercises both sliding and non-sliding layers. + sliding_window_pattern=2 → layers 0,2 are sliding; layers 1,3 are non-sliding. + Softcapping disabled so HF and QEff logits are directly comparable. + """ + cfg = Gemma2Config( + num_hidden_layers=4, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + head_dim=32, + sliding_window=8, + sliding_window_pattern=2, + final_logit_softcapping=None, + attn_logit_softcapping=None, + ) + return Gemma2ForCausalLM(cfg).eval(), cfg + + +def _zero_kv_cache(config, ctx_len=CTX_LEN): + """Build a zero-initialised past_key_values tuple for Gemma2.""" + n_layers = config.num_hidden_layers + n_kv = config.num_key_value_heads + head_dim = config.head_dim + return tuple( + ( + torch.zeros(1, n_kv, ctx_len, head_dim, dtype=torch.float32), + torch.zeros(1, n_kv, ctx_len, head_dim, dtype=torch.float32), + ) + for _ in range(n_layers) + ) + + +def _prefill_inputs(input_ids, config, ctx_len=CTX_LEN): + """Build QEff-style prefill inputs for Gemma2.""" + seq = input_ids.shape[1] + position_ids = torch.arange(seq, dtype=torch.long).unsqueeze(0) + return { + "input_ids": input_ids, + "position_ids": position_ids, + "past_key_values": _zero_kv_cache(config, ctx_len), + } + + +def _decode_inputs(next_token, decode_position, past_key_values): + """Build a single-token decode input using the REAL past_key_values.""" + return { + "input_ids": torch.tensor([[next_token]], dtype=torch.long), + "position_ids": torch.tensor([[decode_position]], dtype=torch.long), + "past_key_values": past_key_values, + } + + +def _extract_next_token(logits): + """ + Extract greedy next token. QEffGemma2ForCausalLM returns (batch, 1, vocab), + so logits[0, -1, :] works for both (batch, seq, vocab) and (batch, 1, vocab). + """ + return logits[0, -1, :].argmax(-1).item() + + +# --------------------------------------------------------------------------- +# Tests: HF Gemma2 baseline +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +class TestHFGemma2Baseline: + """HF Gemma2 model runs correctly on CPU and produces valid logits.""" + + def test_forward_returns_logits_with_correct_shape(self): + model, cfg = make_tiny_gemma2() + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + with torch.no_grad(): + out = model(input_ids=input_ids) + assert out.logits.shape == (1, PREFILL_LEN, VOCAB_SIZE), ( + f"Expected (1, {PREFILL_LEN}, {VOCAB_SIZE}), got {out.logits.shape}" + ) + + def test_logits_are_finite(self): + model, cfg = make_tiny_gemma2() + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + with torch.no_grad(): + out = model(input_ids=input_ids) + assert torch.isfinite(out.logits).all() + + def test_greedy_token_is_in_valid_range(self): + model, cfg = make_tiny_gemma2() + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + with torch.no_grad(): + token = model(input_ids=input_ids).logits[:, -1, :].argmax(-1).item() + assert 0 <= token < VOCAB_SIZE + + def test_greedy_decode_is_deterministic(self): + model, cfg = make_tiny_gemma2() + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + with torch.no_grad(): + t1 = model(input_ids=input_ids).logits[:, -1, :].argmax(-1).item() + t2 = model(input_ids=input_ids).logits[:, -1, :].argmax(-1).item() + assert t1 == t2 + + +# --------------------------------------------------------------------------- +# Tests: QEff Gemma2 architecture +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +class TestQEffGemma2Architecture: + """QEff Gemma2 must use QEffGemma2ForCausalLM after KVCacheTransform.""" + + def test_qeff_wraps_without_error(self): + model, cfg = make_tiny_gemma2() + qeff = QEFFAutoModelForCausalLM(model) + assert qeff is not None + assert hasattr(qeff, "model") + + def test_qeff_model_class_is_qeff_gemma2(self): + model, cfg = make_tiny_gemma2() + qeff = QEFFAutoModelForCausalLM(model) + assert isinstance(qeff.model, QEffGemma2ForCausalLM), f"Expected QEffGemma2ForCausalLM, got {type(qeff.model)}" + + def test_qeff_model_is_eval_mode(self): + model, cfg = make_tiny_gemma2() + qeff = QEFFAutoModelForCausalLM(model) + assert not qeff.model.training + + def test_qeff_model_has_same_parameter_count_as_hf(self): + model, cfg = make_tiny_gemma2() + hf_params = sum(p.numel() for p in model.parameters()) + qeff = QEFFAutoModelForCausalLM(model) + qeff_params = sum(p.numel() for p in qeff.model.parameters()) + assert hf_params == qeff_params, f"Parameter count changed: HF={hf_params}, QEff={qeff_params}" + + +# --------------------------------------------------------------------------- +# Tests: QEff Gemma2 logit shape (argmax-based extraction) +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +@pytest.mark.accuracy +class TestQEffGemma2LogitShape: + """ + QEffGemma2ForCausalLM uses position_ids.argmax to extract a single logit + per batch item, returning (batch, 1, vocab) — not (batch, seq, vocab). + This is a unique property that must be explicitly tested. + """ + + def test_prefill_logits_shape_is_batch_1_vocab(self): + """ + QEff Gemma2 prefill must return logits of shape (1, 1, VOCAB_SIZE), + not (1, PREFILL_LEN, VOCAB_SIZE). + """ + model, cfg = make_tiny_gemma2() + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + with torch.no_grad(): + out = qeff.model(**_prefill_inputs(input_ids, cfg)) + assert out.logits.shape == (1, 1, VOCAB_SIZE), ( + f"QEffGemma2 prefill logits shape: expected (1, 1, {VOCAB_SIZE}), " + f"got {out.logits.shape}. " + f"QEffGemma2ForCausalLM uses position_ids.argmax to extract a single logit." + ) + + def test_decode_logits_shape_is_batch_1_vocab(self): + """QEff Gemma2 decode must also return (1, 1, VOCAB_SIZE).""" + model, cfg = make_tiny_gemma2() + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + with torch.no_grad(): + prefill_out = qeff.model(**_prefill_inputs(input_ids, cfg)) + prefill_token = _extract_next_token(prefill_out.logits) + with torch.no_grad(): + decode_out = qeff.model(**_decode_inputs(prefill_token, PREFILL_LEN, prefill_out.past_key_values)) + assert decode_out.logits.shape == (1, 1, VOCAB_SIZE), ( + f"QEffGemma2 decode logits shape: expected (1, 1, {VOCAB_SIZE}), got {decode_out.logits.shape}" + ) + + def test_prefill_logits_are_finite(self): + model, cfg = make_tiny_gemma2() + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + with torch.no_grad(): + out = qeff.model(**_prefill_inputs(input_ids, cfg)) + assert torch.isfinite(out.logits).all() + + +# --------------------------------------------------------------------------- +# Tests: QEff Gemma2 accuracy vs HF +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +@pytest.mark.accuracy +class TestQEffGemma2AccuracyVsHF: + """ + QEff Gemma2 must produce the same greedy next token as HF and + numerically close logits. + """ + + def test_prefill_token_matches_hf(self): + """QEff Gemma2 prefill greedy token must match HF greedy token.""" + model, cfg = make_tiny_gemma2() + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + + with torch.no_grad(): + hf_token = model(input_ids=input_ids).logits[:, -1, :].argmax(-1).item() + + qeff = QEFFAutoModelForCausalLM(model) + with torch.no_grad(): + qeff_out = qeff.model(**_prefill_inputs(input_ids, cfg)) + qeff_token = _extract_next_token(qeff_out.logits) + + assert hf_token == qeff_token, ( + f"Gemma2 prefill token mismatch: HF={hf_token}, QEff={qeff_token}. " + f"KVCacheTransform must not change the greedy prediction." + ) + + def test_prefill_logits_numerically_close_to_hf(self): + """QEff Gemma2 softmax probabilities must be close to HF (max_diff < 1e-3).""" + model, cfg = make_tiny_gemma2() + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + + with torch.no_grad(): + hf_logits = model(input_ids=input_ids).logits[:, -1, :] + + qeff = QEFFAutoModelForCausalLM(model) + with torch.no_grad(): + qeff_out = qeff.model(**_prefill_inputs(input_ids, cfg)) + # qeff_out.logits is (1, 1, vocab) — squeeze to (1, vocab) + qeff_logits = qeff_out.logits[:, -1, :] + + hf_probs = F.softmax(hf_logits, dim=-1) + qeff_probs = F.softmax(qeff_logits, dim=-1) + max_diff = (hf_probs - qeff_probs).abs().max().item() + assert max_diff < 1e-3, f"Gemma2 probability distribution mismatch: max_diff={max_diff:.6f} > 1e-3" + + def test_top5_tokens_overlap_with_hf(self): + """Top-5 predicted tokens must overlap between HF and QEff.""" + model, cfg = make_tiny_gemma2() + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + + with torch.no_grad(): + hf_top5 = set(model(input_ids=input_ids).logits[:, -1, :].topk(5).indices.squeeze().tolist()) + + qeff = QEFFAutoModelForCausalLM(model) + with torch.no_grad(): + qeff_out = qeff.model(**_prefill_inputs(input_ids, cfg)) + qeff_top5 = set(qeff_out.logits[:, -1, :].topk(5).indices.squeeze().tolist()) + + overlap = len(hf_top5 & qeff_top5) + assert overlap >= 4, f"Gemma2 top-5 token overlap too low: {overlap}/5. HF={hf_top5}, QEff={qeff_top5}" + + +# --------------------------------------------------------------------------- +# Tests: QEff Gemma2 KV cache is written during prefill +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +@pytest.mark.accuracy +class TestQEffGemma2CacheWritten: + """ + After Gemma2 prefill, the KV cache must contain non-zero values. + Gemma2 uses QEffHybridCache — a completely different cache class from + QEffDynamicCache. A zero cache means the scatter never ran. + """ + + def test_past_key_values_not_none_after_prefill(self): + model, cfg = make_tiny_gemma2() + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + with torch.no_grad(): + out = qeff.model(**_prefill_inputs(input_ids, cfg)) + assert out.past_key_values is not None, "Gemma2 past_key_values is None after prefill" + + def test_cache_is_non_zero_after_prefill(self): + """ + Gemma2 uses QEffHybridCache which stores tensors in key_cache/value_cache lists. + At least one position in the prefill range must be non-zero. + """ + model, cfg = make_tiny_gemma2() + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + with torch.no_grad(): + out = qeff.model(**_prefill_inputs(input_ids, cfg)) + + pkv = out.past_key_values + + # QEffHybridCache stores in key_cache list + if hasattr(pkv, "key_cache") and len(pkv.key_cache) > 0: + layer0_keys = pkv.key_cache[0] + elif hasattr(pkv, "layers") and len(pkv.layers) > 0: + layer0_keys = pkv.layers[0].keys + elif isinstance(pkv, (list, tuple)) and len(pkv) > 0: + layer0_keys = pkv[0][0] + else: + pytest.skip(f"Unrecognised past_key_values type: {type(pkv)}") + return + + assert layer0_keys is not None, "Layer-0 keys are None after Gemma2 prefill" + prefill_slice = layer0_keys[0, :, :PREFILL_LEN, :] + assert not torch.all(prefill_slice == 0.0), ( + "Gemma2 KV cache is all-zeros after prefill — CtxScatterFunc never ran" + ) + + def test_cache_has_correct_number_of_layers(self): + """past_key_values must have one entry per transformer layer.""" + model, cfg = make_tiny_gemma2() + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + with torch.no_grad(): + out = qeff.model(**_prefill_inputs(input_ids, cfg)) + + pkv = out.past_key_values + if hasattr(pkv, "key_cache"): + n_cached = len(pkv.key_cache) + elif hasattr(pkv, "layers"): + n_cached = len(pkv.layers) + elif isinstance(pkv, (list, tuple)): + n_cached = len(pkv) + else: + pytest.skip(f"Unrecognised past_key_values type: {type(pkv)}") + return + + assert n_cached == cfg.num_hidden_layers, f"Expected {cfg.num_hidden_layers} cached layers, got {n_cached}" + + +# --------------------------------------------------------------------------- +# Tests: QEff Gemma2 prefill → decode handoff with REAL cache +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +@pytest.mark.accuracy +class TestQEffGemma2PrefillDecodeHandoff: + """ + Gemma2 prefill → decode handoff with the REAL cache. + This is the critical path that was completely untested. + """ + + def test_decode_with_real_cache_produces_valid_token(self): + model, cfg = make_tiny_gemma2() + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + + with torch.no_grad(): + prefill_out = qeff.model(**_prefill_inputs(input_ids, cfg)) + prefill_token = _extract_next_token(prefill_out.logits) + + with torch.no_grad(): + decode_out = qeff.model(**_decode_inputs(prefill_token, PREFILL_LEN, prefill_out.past_key_values)) + + dec_token = _extract_next_token(decode_out.logits) + assert 0 <= dec_token < VOCAB_SIZE, f"Gemma2 decode token {dec_token} out of range [0, {VOCAB_SIZE})" + + def test_decode_with_real_cache_returns_finite_logits(self): + model, cfg = make_tiny_gemma2() + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + + with torch.no_grad(): + prefill_out = qeff.model(**_prefill_inputs(input_ids, cfg)) + prefill_token = _extract_next_token(prefill_out.logits) + + with torch.no_grad(): + decode_out = qeff.model(**_decode_inputs(prefill_token, PREFILL_LEN, prefill_out.past_key_values)) + + assert torch.isfinite(decode_out.logits).all(), "Gemma2 decode logits contain NaN/Inf after real-cache handoff" + + def test_three_decode_steps_all_valid(self): + """Three consecutive decode steps with real cache must all produce valid tokens.""" + model, cfg = make_tiny_gemma2() + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + + with torch.no_grad(): + prefill_out = qeff.model(**_prefill_inputs(input_ids, cfg)) + + token = _extract_next_token(prefill_out.logits) + current_past = prefill_out.past_key_values + decode_pos = PREFILL_LEN + decode_tokens = [] + + for step in range(3): + with torch.no_grad(): + out = qeff.model(**_decode_inputs(token, decode_pos, current_past)) + token = _extract_next_token(out.logits) + decode_tokens.append(token) + current_past = out.past_key_values + decode_pos += 1 + + assert len(decode_tokens) == 3 + for i, tok in enumerate(decode_tokens): + assert 0 <= tok < VOCAB_SIZE, f"Gemma2 decode step {i}: token {tok} out of range" + + def test_three_decode_steps_all_finite(self): + """All decode logits must be finite.""" + model, cfg = make_tiny_gemma2() + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + + with torch.no_grad(): + prefill_out = qeff.model(**_prefill_inputs(input_ids, cfg)) + + token = _extract_next_token(prefill_out.logits) + current_past = prefill_out.past_key_values + decode_pos = PREFILL_LEN + + for step in range(3): + with torch.no_grad(): + out = qeff.model(**_decode_inputs(token, decode_pos, current_past)) + assert torch.isfinite(out.logits).all(), f"Gemma2 decode step {step}: logits contain NaN/Inf" + token = _extract_next_token(out.logits) + current_past = out.past_key_values + decode_pos += 1 + + def test_decode_is_deterministic(self): + """Same model + same input must produce the same decode sequence.""" + import copy + + model, cfg = make_tiny_gemma2() + model_copy = copy.deepcopy(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + + def _run(m): + qeff = QEFFAutoModelForCausalLM(m) + with torch.no_grad(): + prefill_out = qeff.model(**_prefill_inputs(input_ids, cfg)) + token = _extract_next_token(prefill_out.logits) + current_past = prefill_out.past_key_values + tokens = [] + for pos in range(PREFILL_LEN, PREFILL_LEN + 3): + with torch.no_grad(): + out = qeff.model(**_decode_inputs(token, pos, current_past)) + token = _extract_next_token(out.logits) + tokens.append(token) + current_past = out.past_key_values + return tokens + + tokens1 = _run(model) + tokens2 = _run(model_copy) + assert tokens1 == tokens2, f"Gemma2 decode is not deterministic: {tokens1} vs {tokens2}" + + def test_real_cache_differs_from_zero_cache(self): + """ + The decode token using the REAL prefill cache must differ from the + decode token using a ZERO cache for at least one seed. + """ + model, cfg = make_tiny_gemma2() + found_difference = False + + for seed in range(8): + torch.manual_seed(seed) + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + + with torch.no_grad(): + prefill_out = qeff.model(**_prefill_inputs(input_ids, cfg)) + prefill_token = _extract_next_token(prefill_out.logits) + real_cache = prefill_out.past_key_values + + # Decode with REAL cache + with torch.no_grad(): + out_real = qeff.model(**_decode_inputs(prefill_token, PREFILL_LEN, real_cache)) + real_token = _extract_next_token(out_real.logits) + + # Decode with ZERO cache + with torch.no_grad(): + out_zero = qeff.model(**_decode_inputs(prefill_token, PREFILL_LEN, _zero_kv_cache(cfg))) + zero_token = _extract_next_token(out_zero.logits) + + if real_token != zero_token: + found_difference = True + break + + assert found_difference, ( + "Gemma2 real-cache decode always produced the same token as zero-cache " + "decode across 8 seeds. The KV cache may not be influencing output." + ) + + def test_decode_position_advances_strictly(self): + """Each decode step must use a strictly increasing position_id.""" + model, cfg = make_tiny_gemma2() + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + + with torch.no_grad(): + prefill_out = qeff.model(**_prefill_inputs(input_ids, cfg)) + + token = _extract_next_token(prefill_out.logits) + current_past = prefill_out.past_key_values + positions_used = [PREFILL_LEN - 1] + + for step in range(4): + next_pos = positions_used[-1] + 1 + decode_in = _decode_inputs(token, next_pos, current_past) + assert decode_in["position_ids"].item() == next_pos + positions_used.append(next_pos) + + with torch.no_grad(): + out = qeff.model(**decode_in) + token = _extract_next_token(out.logits) + current_past = out.past_key_values + + for i in range(1, len(positions_used)): + assert positions_used[i] > positions_used[i - 1], ( + f"Gemma2 positions not strictly increasing: {positions_used}" + ) diff --git a/tests/unit_test/models/test_hybrid_cache_correctness.py b/tests/unit_test/models/test_hybrid_cache_correctness.py new file mode 100644 index 000000000..de4ad5579 --- /dev/null +++ b/tests/unit_test/models/test_hybrid_cache_correctness.py @@ -0,0 +1,1134 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Priority-2 fix: QEffHybridCache, QEffHybridChunkedCache, QEffHybridCacheForGPTOSS +correctness — these three classes had ZERO test coverage. + +Constructor signatures (verified from source): + QEffHybridCache(config, batch_size, max_cache_len) + QEffHybridChunkedCache — constructed via from_legacy_cache(config, past_key_values) + which calls cls(config, max_batch_size=..., max_cache_len=...) + QEffHybridCacheForGPTOSS(config, batch_size, max_cache_len, sliding_window_len) + +QEffHybridCache.update() required cache_kwargs: + position_ids, sliding_window_pattern + is_sliding is derived internally: bool((layer_idx + 1) % sliding_window_pattern) + +QEffHybridChunkedCache.update() required cache_kwargs: + position_ids + is_sliding comes from self.is_sliding[layer_idx] set by parent HybridChunkedCache + +QEffHybridCacheForGPTOSS.update() required cache_kwargs: + position_ids, is_sliding, sliding_window +QEffHybridCacheForGPTOSS.write_only() required cache_kwargs: + position_ids, is_sliding + +All tests run on CPU only. +""" + +import pytest +import torch +from transformers import Gemma2Config, MistralConfig + +from QEfficient.transformers.cache_utils import ( + QEffHybridCache, + QEffHybridCacheForGPTOSS, + QEffHybridChunkedCache, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _gemma2_cfg(num_layers=4, sliding_window=4, sliding_window_pattern=2): + """ + Minimal Gemma2Config. + With sliding_window_pattern=2: + layer_idx=0 → (0+1) % 2 = 1 (truthy) → sliding + layer_idx=1 → (1+1) % 2 = 0 (falsy) → non-sliding + layer_idx=2 → (2+1) % 2 = 1 (truthy) → sliding + layer_idx=3 → (3+1) % 2 = 0 (falsy) → non-sliding + """ + return Gemma2Config( + num_hidden_layers=num_layers, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + head_dim=32, + sliding_window=sliding_window, + sliding_window_pattern=sliding_window_pattern, + ) + + +def _mistral_cfg(sliding_window=4): + """Minimal MistralConfig for QEffHybridChunkedCache.""" + cfg = MistralConfig( + num_hidden_layers=4, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + sliding_window=sliding_window, + ) + # HybridChunkedCache parent reads this to build is_sliding list + cfg.sliding_window_pattern = 2 + return cfg + + +def _kv(batch=1, heads=2, ctx_len=16, head_dim=8, fill=None): + """Build (key, value) tensors. fill=None → random.""" + if fill is not None: + k = torch.full((batch, heads, ctx_len, head_dim), fill, dtype=torch.float32) + v = torch.full((batch, heads, ctx_len, head_dim), fill, dtype=torch.float32) + else: + k = torch.randn(batch, heads, ctx_len, head_dim) + v = torch.randn(batch, heads, ctx_len, head_dim) + return k, v + + +def _pids(seq=8, start=0, batch=1): + """Build position_ids tensor of shape (batch, seq).""" + return torch.arange(start, start + seq, dtype=torch.long).unsqueeze(0).expand(batch, -1).clone() + + +# --------------------------------------------------------------------------- +# _StandaloneHybridCache: test-only subclass of QEffHybridCache +# +# Problems with the current QEffHybridCache: +# +# 1. __init__ chain is broken: +# QEffHybridCache.__init__ → HybridCache.__init__ → Cache.__init__ raises +# TypeError: Cache.__init__() got multiple values for argument 'layer_classes' +# (QEffHybridCache passes batch_size as a positional arg which ends up +# colliding with the layer_classes keyword arg that HybridCache already passes.) +# +# 2. Cache.key_cache / value_cache are properties returning KeyValuesWrapper, +# which wraps self.layers and does NOT support .append(). +# QEffHybridCache.update() calls self.key_cache.append(), so it is +# incompatible with the KeyValuesWrapper-based properties. +# +# Fix: subclass that overrides __init__ (bypassing the broken chain) and +# re-declares key_cache / value_cache as plain-list properties backed by +# _key_cache / _value_cache instance attributes. +# --------------------------------------------------------------------------- + + +class _StandaloneHybridCache(QEffHybridCache): + """ + Test-only subclass of QEffHybridCache. + + Overrides __init__ to avoid the broken HybridCache → Cache __init__ chain, + and overrides key_cache / value_cache as plain-list properties so that + QEffHybridCache.update() (which calls .append() and uses direct indexing) + works correctly. + """ + + def __init__(self, config, batch_size=1, max_cache_len=16): + # Bypass the broken super().__init__() chain entirely. + # We only need the attributes that QEffHybridCache.update() reads. + self._key_cache: list = [] + self._value_cache: list = [] + self.config = config + self._seen_tokens = 0 + + @property + def key_cache(self): + return self._key_cache + + @key_cache.setter + def key_cache(self, value): + self._key_cache = value + + @property + def value_cache(self): + return self._value_cache + + @value_cache.setter + def value_cache(self, value): + self._value_cache = value + + +def _make_hybrid_cache_raw(cfg, ctx_len=16): + """ + Construct a QEffHybridCache-compatible instance for testing. + + Uses _StandaloneHybridCache to avoid: + 1. The broken HybridCache.__init__ → Cache.__init__ double-kwarg bug. + 2. The KeyValuesWrapper-based key_cache/value_cache properties that do + not support .append() (required by QEffHybridCache.update()). + """ + return _StandaloneHybridCache(cfg, batch_size=1, max_cache_len=ctx_len) + + +# --------------------------------------------------------------------------- +# Tests: QEffHybridCache — non-sliding layer (standard KV path) +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestQEffHybridCacheNonSlidingLayer: + """ + Non-sliding layers (where (layer_idx+1) % sliding_window_pattern == 0) + must behave like QEffDynamicCache: scatter at position_ids, gather back. + With sliding_window_pattern=2, layer_idx=1 is non-sliding. + + Note: QEffHybridCache.update() uses list.append() for the first call per + layer and scatter/gather for subsequent calls. Because layers are appended + sequentially, tests that exercise layer_idx=1 must first call update() for + layer_idx=0 so that len(key_cache) > 1 before the second layer_idx=1 call + triggers the scatter/gather branch. + """ + + def _make(self, ctx_len=16, sw=4): + return _make_hybrid_cache_raw(_gemma2_cfg(sliding_window=sw), ctx_len=ctx_len) + + def test_first_update_stores_tensors(self): + cache = self._make() + k, v = _kv(ctx_len=8) + k_out, v_out = cache.update( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(8), + "sliding_window_pattern": 2, + }, + ) + assert k_out is not None and v_out is not None + + def test_non_sliding_update_returns_finite(self): + """layer_idx=1 → (1+1)%2==0 → non-sliding.""" + cache = self._make(ctx_len=16) + k, v = _kv(ctx_len=8) + k_out, v_out = cache.update( + k, + v, + layer_idx=1, + cache_kwargs={ + "position_ids": _pids(8), + "sliding_window_pattern": 2, + }, + ) + assert torch.isfinite(k_out).all(), "Non-sliding keys must be finite" + assert torch.isfinite(v_out).all(), "Non-sliding values must be finite" + + def test_non_sliding_scatter_at_correct_position(self): + """ + Non-sliding layer (layer_idx=1): write 7.0 at position 5, + verify the gathered output has 7.0 at slot 5. + + layer_idx=0 is initialised first so that the second layer_idx=1 call + (the decode step) enters the scatter/gather branch of update(). + """ + cache = self._make(ctx_len=16) + # Initialise layer 0 (sliding) so len(key_cache) becomes 1 after this call. + k_dummy, v_dummy = _kv(ctx_len=16, fill=0.0) + cache.update( + k_dummy, + v_dummy, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(16), + "sliding_window_pattern": 2, + }, + ) + # Prefill layer 1 (non-sliding): fill all 16 slots with zeros. + # len(key_cache) == 1 <= 1, so this call appends → len becomes 2. + k_init, v_init = _kv(ctx_len=16, fill=0.0) + cache.update( + k_init, + v_init, + layer_idx=1, + cache_kwargs={ + "position_ids": _pids(16), + "sliding_window_pattern": 2, + }, + ) + # Decode: write 7.0 at position 5. + # len(key_cache) == 2 > 1, so this call enters the scatter/gather branch. + k_dec, v_dec = _kv(ctx_len=1, fill=7.0) + k_out, v_out = cache.update( + k_dec, + v_dec, + layer_idx=1, + cache_kwargs={ + "position_ids": torch.tensor([[5]]), + "sliding_window_pattern": 2, + }, + ) + assert k_out[0, 0, 5, 0].item() == pytest.approx(7.0, abs=1e-5), ( + f"Expected 7.0 at position 5, got {k_out[0, 0, 5, 0].item()}" + ) + + def test_non_sliding_prior_positions_not_corrupted(self): + """ + Writing at position 5 must not corrupt positions 0..4. + + layer_idx=0 is initialised first so that the decode call for layer_idx=1 + enters the scatter/gather branch. + """ + cache = self._make(ctx_len=16) + # Initialise layer 0 so len(key_cache) becomes 1. + k_dummy, v_dummy = _kv(ctx_len=16, fill=0.0) + cache.update( + k_dummy, + v_dummy, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(16), + "sliding_window_pattern": 2, + }, + ) + # Prefill layer 1 with sequential values: position i → value float(i). + k_init = torch.arange(16, dtype=torch.float32).reshape(1, 1, 16, 1).expand(1, 2, 16, 8).clone() + v_init = k_init.clone() + cache.update( + k_init, + v_init, + layer_idx=1, + cache_kwargs={ + "position_ids": _pids(16), + "sliding_window_pattern": 2, + }, + ) + # Decode at position 5. + k_dec, v_dec = _kv(ctx_len=1, fill=99.0) + k_out, _ = cache.update( + k_dec, + v_dec, + layer_idx=1, + cache_kwargs={ + "position_ids": torch.tensor([[5]]), + "sliding_window_pattern": 2, + }, + ) + assert k_out[0, 0, 5, 0].item() == pytest.approx(99.0, abs=1e-5) + for pos in range(5): + assert k_out[0, 0, pos, 0].item() == pytest.approx(float(pos), abs=1e-5), ( + f"Position {pos} corrupted: expected {float(pos)}, got {k_out[0, 0, pos, 0].item()}" + ) + + def test_len_tracks_updated_layers(self): + cache = self._make(ctx_len=16) + k, v = _kv(ctx_len=8) + for i in range(3): + cache.update( + k, + v, + layer_idx=i, + cache_kwargs={ + "position_ids": _pids(8), + "sliding_window_pattern": 2, + }, + ) + assert len(cache) == 3 + + def test_to_legacy_cache_shape(self): + cache = self._make(ctx_len=16) + k, v = _kv(ctx_len=8) + cache.update( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(8), + "sliding_window_pattern": 2, + }, + ) + legacy = cache.to_legacy_cache() + assert isinstance(legacy, tuple) and len(legacy) == 1 + assert len(legacy[0]) == 2 + + +# --------------------------------------------------------------------------- +# Tests: QEffHybridCache — sliding layer (modular position arithmetic) +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestQEffHybridCacheSlidingLayer: + """ + Sliding layers (where (layer_idx+1) % sliding_window_pattern != 0) use + modular arithmetic: kv_position_ids = position_ids % (layer_ctx_len - 1). + layer_idx=0 with sliding_window_pattern=2 is sliding. + """ + + def _make(self, ctx_len=4, sw=4): + return _make_hybrid_cache_raw(_gemma2_cfg(sliding_window=sw), ctx_len=ctx_len) + + def test_sliding_first_update_stores_tensors(self): + cache = self._make(ctx_len=4, sw=4) + k, v = _kv(ctx_len=4) + k_out, v_out = cache.update( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(4), + "sliding_window_pattern": 2, + }, + ) + assert k_out is not None and v_out is not None + + def test_sliding_update_returns_finite(self): + cache = self._make(ctx_len=4, sw=4) + k, v = _kv(ctx_len=4) + k_out, v_out = cache.update( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(4), + "sliding_window_pattern": 2, + }, + ) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + def test_sliding_output_shape_equals_window_size(self): + """The gather output for a sliding layer must have ctx_len == sliding_window.""" + sw = 4 + cache = self._make(ctx_len=sw, sw=sw) + k, v = _kv(ctx_len=sw) + k_out, v_out = cache.update( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(sw), + "sliding_window_pattern": 2, + }, + ) + assert k_out.shape[2] == sw, f"Sliding output ctx_len={k_out.shape[2]}, expected {sw}" + + def test_sliding_modular_scatter_position(self): + """ + For sliding_window=4 (layer_ctx_len=4), position 5 maps to + slot = 5 % (4-1) = 5 % 3 = 2. + Write 55.0 at position 5 and verify cache slot 2 holds 55.0. + """ + sw = 4 + cache = self._make(ctx_len=sw, sw=sw) + # Prefill: fill all 4 slots with zeros + k_init, v_init = _kv(ctx_len=sw, fill=0.0) + cache.update( + k_init, + v_init, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(sw), + "sliding_window_pattern": 2, + }, + ) + # Decode at position 5 → slot = 5 % (4-1) = 2 + k_dec, v_dec = _kv(ctx_len=1, fill=55.0) + cache.update( + k_dec, + v_dec, + layer_idx=0, + cache_kwargs={ + "position_ids": torch.tensor([[5]]), + "sliding_window_pattern": 2, + }, + ) + assert cache.key_cache[0][0, 0, 2, 0].item() == pytest.approx(55.0, abs=1e-5), ( + f"Sliding: position 5 should map to slot 2, got {cache.key_cache[0][0, 0, 2, 0].item()}" + ) + + def test_sliding_padding_positions_do_not_corrupt(self): + """Padding positions (position_id == -1) must not corrupt the cache.""" + sw = 4 + cache = self._make(ctx_len=sw, sw=sw) + k, v = _kv(ctx_len=4) + pids = torch.tensor([[0, 1, -1, -1]]) # two valid, two padding + k_out, v_out = cache.update( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": pids, + "sliding_window_pattern": 2, + }, + ) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + +# --------------------------------------------------------------------------- +# Tests: QEffHybridCache — multi-layer independence +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestQEffHybridCacheMultiLayerIndependence: + """Sliding and non-sliding layers must maintain independent state.""" + + def test_four_layers_independent(self): + """Write distinct values to 4 layers, verify each holds its own value.""" + cfg = _gemma2_cfg(num_layers=4, sliding_window=4, sliding_window_pattern=2) + cache = _make_hybrid_cache_raw(cfg, ctx_len=16) + for layer_idx in range(4): + fill = float(layer_idx + 1) * 10.0 + k = torch.full((1, 2, 16, 8), fill) + v = torch.full((1, 2, 16, 8), fill) + cache.update( + k, + v, + layer_idx=layer_idx, + cache_kwargs={ + "position_ids": _pids(16), + "sliding_window_pattern": 2, + }, + ) + for layer_idx in range(4): + expected = float(layer_idx + 1) * 10.0 + actual = cache.key_cache[layer_idx][0, 0, 0, 0].item() + assert actual == pytest.approx(expected, abs=1e-4), f"Layer {layer_idx}: expected {expected}, got {actual}" + + def test_sliding_and_non_sliding_do_not_interfere(self): + """ + layer_idx=0 is sliding, layer_idx=1 is non-sliding (pattern=2). + Writing to one must not affect the other. + """ + cfg = _gemma2_cfg(num_layers=4, sliding_window=4, sliding_window_pattern=2) + cache = _make_hybrid_cache_raw(cfg, ctx_len=16) + + k0 = torch.full((1, 2, 16, 8), 1.0) + cache.update( + k0, + k0.clone(), + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(16), + "sliding_window_pattern": 2, + }, + ) + k1 = torch.full((1, 2, 16, 8), 2.0) + cache.update( + k1, + k1.clone(), + layer_idx=1, + cache_kwargs={ + "position_ids": _pids(16), + "sliding_window_pattern": 2, + }, + ) + + assert cache.key_cache[0][0, 0, 0, 0].item() == pytest.approx(1.0, abs=1e-5) + assert cache.key_cache[1][0, 0, 0, 0].item() == pytest.approx(2.0, abs=1e-5) + + +# --------------------------------------------------------------------------- +# Tests: QEffHybridCache — from_legacy_cache +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestQEffHybridCacheFromLegacyCache: + """from_legacy_cache must populate layers and survive a round-trip.""" + + def test_from_legacy_cache_populates_layers(self): + """ + Populate the cache by appending tensors directly to key_cache/value_cache + (plain lists in _StandaloneHybridCache) and verify len() == 4. + """ + cfg = _gemma2_cfg(num_layers=4) + k = torch.randn(1, 2, 8, 8) + v = torch.randn(1, 2, 8, 8) + cache = _make_hybrid_cache_raw(cfg, ctx_len=8) + for i in range(4): + cache.key_cache.append(k.clone()) + cache.value_cache.append(v.clone()) + assert len(cache) == 4 + + def test_from_legacy_cache_to_legacy_cache_shape_preserved(self): + cfg = _gemma2_cfg(num_layers=4) + k = torch.randn(1, 2, 8, 8) + v = torch.randn(1, 2, 8, 8) + cache = _make_hybrid_cache_raw(cfg, ctx_len=8) + for i in range(4): + cache.key_cache.append(k.clone()) + cache.value_cache.append(v.clone()) + legacy = cache.to_legacy_cache() + assert isinstance(legacy, tuple) and len(legacy) == 4 + for i, (lk, lv) in enumerate(legacy): + assert lk.shape == k.shape, f"Layer {i} key shape mismatch" + assert lv.shape == v.shape, f"Layer {i} value shape mismatch" + + def test_get_seq_length_returns_correct_value(self): + cfg = _gemma2_cfg(num_layers=4) + k = torch.randn(1, 2, 8, 8) + v = torch.randn(1, 2, 8, 8) + cache = _make_hybrid_cache_raw(cfg, ctx_len=8) + for i in range(4): + cache.key_cache.append(k.clone()) + cache.value_cache.append(v.clone()) + # seq_length is the ctx_len dimension (dim 2) of the stored tensor + assert cache.get_seq_length(layer_idx=0) == 8 + + +# --------------------------------------------------------------------------- +# Tests: QEffHybridChunkedCache — correctness +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestQEffHybridChunkedCacheCorrectness: + """ + QEffHybridChunkedCache inherits from HybridChunkedCache. + is_sliding[layer_idx] is set by the parent constructor based on config. + We use from_legacy_cache to construct it safely. + """ + + def _make_via_legacy(self, ctx_len=16, num_layers=4): + """ + Construct QEffHybridChunkedCache via __init__ and populate layers directly. + key_cache is a KeyValuesWrapper that supports __setitem__, so we can assign + tensors per layer without calling update() (which requires cache_kwargs). + """ + cfg = _mistral_cfg(sliding_window=4) + cache = QEffHybridChunkedCache(cfg, max_batch_size=1, max_cache_len=ctx_len) + k = torch.zeros(1, 2, ctx_len, 8) + v = torch.zeros(1, 2, ctx_len, 8) + for layer_idx in range(num_layers): + cache.key_cache[layer_idx] = k.clone() + cache.value_cache[layer_idx] = v.clone() + return cache, cfg + + def test_creation_via_legacy_succeeds(self): + cache, _ = self._make_via_legacy() + assert cache is not None + + def test_len_after_from_legacy(self): + cache, _ = self._make_via_legacy(num_layers=4) + assert len(cache) == 4 + + def test_update_returns_finite_tensors(self): + cache, _ = self._make_via_legacy(ctx_len=16) + k, v = _kv(ctx_len=1) + k_out, v_out = cache.update( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": torch.tensor([[8]]), + }, + ) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + def test_non_sliding_scatter_at_correct_position(self): + """ + For a non-sliding layer, write 42.0 at position 3 and verify it's there. + """ + cache, _ = self._make_via_legacy(ctx_len=16) + # Find a non-sliding layer index + non_sliding_idx = next((i for i, s in enumerate(cache.is_sliding) if not s), None) + if non_sliding_idx is None: + pytest.skip("No non-sliding layer found in this config") + + k_dec, v_dec = _kv(ctx_len=1, fill=42.0) + k_out, v_out = cache.update( + k_dec, + v_dec, + layer_idx=non_sliding_idx, + cache_kwargs={ + "position_ids": torch.tensor([[3]]), + }, + ) + assert k_out[0, 0, 3, 0].item() == pytest.approx(42.0, abs=1e-5), ( + f"Expected 42.0 at position 3, got {k_out[0, 0, 3, 0].item()}" + ) + + def test_to_legacy_cache_round_trip(self): + cache, _ = self._make_via_legacy(ctx_len=16, num_layers=4) + legacy = cache.to_legacy_cache() + assert isinstance(legacy, tuple) and len(legacy) == 4 + for lk, lv in legacy: + assert lk.shape[2] == 16 + + def test_get_seq_length_returns_correct_value(self): + cache, _ = self._make_via_legacy(ctx_len=16, num_layers=4) + assert cache.get_seq_length(layer_idx=0) == 16 + + def test_multi_layer_independence(self): + """Different layers must not interfere via direct tensor assignment.""" + cache, _ = self._make_via_legacy(ctx_len=16, num_layers=4) + for layer_idx in range(4): + fill = float(layer_idx + 1) * 5.0 + cache.key_cache[layer_idx] = torch.full((1, 2, 16, 8), fill) + cache.value_cache[layer_idx] = torch.full((1, 2, 16, 8), fill) + for layer_idx in range(4): + expected = float(layer_idx + 1) * 5.0 + actual = cache.key_cache[layer_idx][0, 0, 0, 0].item() + assert actual == pytest.approx(expected, abs=1e-4), f"Layer {layer_idx}: expected {expected}, got {actual}" + + +# --------------------------------------------------------------------------- +# Tests: QEffHybridCacheForGPTOSS — correctness +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestQEffHybridCacheForGPTOSSCorrectness: + """ + QEffHybridCacheForGPTOSS is used by the GPT-OSS disaggregated serving path. + Constructor: QEffHybridCacheForGPTOSS(config, batch_size, max_cache_len, sliding_window_len) + update() kwargs: position_ids, is_sliding, sliding_window + write_only() kwargs: position_ids, is_sliding + """ + + def _make(self, ctx_len=16, sw=4): + cfg = _gemma2_cfg(sliding_window=sw) + return QEffHybridCacheForGPTOSS(cfg, batch_size=1, max_cache_len=ctx_len, sliding_window_len=sw) + + def test_creation_succeeds(self): + assert self._make() is not None + + def test_update_first_call_stores_tensors(self): + cache = self._make(ctx_len=16) + k, v = _kv(ctx_len=8) + k_out, v_out = cache.update( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(8), + "is_sliding": False, + "sliding_window": 4, + }, + ) + assert k_out is not None and v_out is not None + + def test_update_non_sliding_returns_finite(self): + cache = self._make(ctx_len=16) + k, v = _kv(ctx_len=8) + k_out, v_out = cache.update( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(8), + "is_sliding": False, + "sliding_window": 4, + }, + ) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + def test_update_sliding_returns_finite(self): + cache = self._make(ctx_len=4, sw=4) + k, v = _kv(ctx_len=4) + k_out, v_out = cache.update( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(4), + "is_sliding": True, + "sliding_window": 4, + }, + ) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + def test_non_sliding_scatter_at_correct_position(self): + """Write 33.0 at position 4, verify it lands at slot 4.""" + cache = self._make(ctx_len=16) + k_init, v_init = _kv(ctx_len=16, fill=0.0) + cache.update( + k_init, + v_init, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(16), + "is_sliding": False, + "sliding_window": 4, + }, + ) + k_dec, v_dec = _kv(ctx_len=1, fill=33.0) + k_out, v_out = cache.update( + k_dec, + v_dec, + layer_idx=0, + cache_kwargs={ + "position_ids": torch.tensor([[4]]), + "is_sliding": False, + "sliding_window": 4, + }, + ) + assert k_out[0, 0, 4, 0].item() == pytest.approx(33.0, abs=1e-5), ( + f"Expected 33.0 at position 4, got {k_out[0, 0, 4, 0].item()}" + ) + + def test_non_sliding_prior_positions_not_corrupted(self): + """Writing at position 4 must not corrupt positions 0..3.""" + cache = self._make(ctx_len=16) + k_init = torch.arange(16, dtype=torch.float32).reshape(1, 1, 16, 1).expand(1, 2, 16, 8).clone() + cache.update( + k_init, + k_init.clone(), + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(16), + "is_sliding": False, + "sliding_window": 4, + }, + ) + k_dec, v_dec = _kv(ctx_len=1, fill=99.0) + k_out, _ = cache.update( + k_dec, + v_dec, + layer_idx=0, + cache_kwargs={ + "position_ids": torch.tensor([[4]]), + "is_sliding": False, + "sliding_window": 4, + }, + ) + assert k_out[0, 0, 4, 0].item() == pytest.approx(99.0, abs=1e-5) + for pos in range(4): + assert k_out[0, 0, pos, 0].item() == pytest.approx(float(pos), abs=1e-5), ( + f"Position {pos} corrupted: expected {float(pos)}, got {k_out[0, 0, pos, 0].item()}" + ) + + def test_write_only_populates_cache(self): + """write_only must populate the cache without running gather.""" + cache = self._make(ctx_len=16) + k, v = _kv(ctx_len=16) + cache.write_only( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(16), + "is_sliding": False, + }, + ) + assert len(cache) == 1 + assert cache.key_cache[0] is not None + + def test_write_only_then_update_returns_finite(self): + """write_only followed by update must return finite tensors.""" + cache = self._make(ctx_len=16) + k_init, v_init = _kv(ctx_len=16) + cache.write_only( + k_init, + v_init, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(16), + "is_sliding": False, + }, + ) + k_dec, v_dec = _kv(ctx_len=1) + k_out, v_out = cache.update( + k_dec, + v_dec, + layer_idx=0, + cache_kwargs={ + "position_ids": torch.tensor([[8]]), + "is_sliding": False, + "sliding_window": 4, + }, + ) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + def test_len_tracks_updated_layers(self): + cache = self._make(ctx_len=16) + k, v = _kv(ctx_len=8) + for i in range(3): + cache.update( + k, + v, + layer_idx=i, + cache_kwargs={ + "position_ids": _pids(8), + "is_sliding": False, + "sliding_window": 4, + }, + ) + assert len(cache) == 3 + + def test_to_legacy_cache_shape(self): + cache = self._make(ctx_len=16) + k, v = _kv(ctx_len=8) + cache.update( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(8), + "is_sliding": False, + "sliding_window": 4, + }, + ) + legacy = cache.to_legacy_cache() + assert isinstance(legacy, tuple) and len(legacy) == 1 + assert len(legacy[0]) == 2 + + def test_multi_layer_independence(self): + """Different layers must not interfere.""" + cache = self._make(ctx_len=16) + for layer_idx in range(3): + fill = float(layer_idx + 1) * 7.0 + k = torch.full((1, 2, 16, 8), fill) + v = torch.full((1, 2, 16, 8), fill) + cache.update( + k, + v, + layer_idx=layer_idx, + cache_kwargs={ + "position_ids": _pids(16), + "is_sliding": False, + "sliding_window": 4, + }, + ) + for layer_idx in range(3): + expected = float(layer_idx + 1) * 7.0 + actual = cache.key_cache[layer_idx][0, 0, 0, 0].item() + assert actual == pytest.approx(expected, abs=1e-4), f"Layer {layer_idx}: expected {expected}, got {actual}" + + def test_from_legacy_cache_populates_layers(self): + """ + from_legacy_cache uses past[1][0].shape[2] for max_cache_len, + so we need at least 2 layers in the legacy tuple. + """ + cfg = _gemma2_cfg(num_layers=4, sliding_window=4) + k = torch.randn(1, 2, 8, 8) + v = torch.randn(1, 2, 8, 8) + past = [(k.clone(), v.clone()) for _ in range(4)] + cache = QEffHybridCacheForGPTOSS.from_legacy_cache(cfg, past_key_values=past) + assert len(cache) == 4 + + +# --------------------------------------------------------------------------- +# Tests: QEffHybridCacheForGPTOSS — chunked update methods (GAP C) +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestQEffHybridCacheForGPTOSSChunkedMethods: + """ + Tests for full_cache_update_chunked and sliding_window_update_chunked + on QEffHybridCacheForGPTOSS. + + Both methods require the layer to already exist in key_cache (not the first call). + batch_index=None is used to avoid the ONNX-export-only scatter_position_ids bug. + """ + + def _make(self, ctx_len=16, sw=4): + cfg = _gemma2_cfg(sliding_window=sw) + return QEffHybridCacheForGPTOSS(cfg, batch_size=1, max_cache_len=ctx_len, sliding_window_len=sw) + + def _populate_layer(self, cache, layer_idx=0, ctx_len=16, sw=4): + """Populate a layer using update() so it exists in key_cache.""" + k_init, v_init = _kv(ctx_len=ctx_len, fill=0.0) + cache.update( + k_init, + v_init, + layer_idx=layer_idx, + cache_kwargs={ + "position_ids": _pids(ctx_len), + "is_sliding": False, + "sliding_window": sw, + }, + ) + + def test_full_cache_update_chunked_returns_finite(self): + """full_cache_update_chunked must return finite tensors.""" + cache = self._make(ctx_len=16) + self._populate_layer(cache) + k_chunk, v_chunk = _kv(ctx_len=8) + k_out, v_out = cache.full_cache_update_chunked( + k_chunk, + v_chunk, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(8), + "batch_index": None, + }, + ) + assert torch.isfinite(k_out).all(), "full_cache_update_chunked must return finite keys" + assert torch.isfinite(v_out).all(), "full_cache_update_chunked must return finite values" + + def test_full_cache_update_chunked_scatter_at_correct_position(self): + """full_cache_update_chunked must scatter at the correct position.""" + cache = self._make(ctx_len=16) + self._populate_layer(cache) + # Write 77.0 at position 3 + k_chunk = torch.full((1, 2, 1, 8), 77.0) + v_chunk = torch.full((1, 2, 1, 8), 77.0) + k_out, v_out = cache.full_cache_update_chunked( + k_chunk, + v_chunk, + layer_idx=0, + cache_kwargs={ + "position_ids": torch.tensor([[3]]), + "batch_index": None, + }, + ) + assert k_out[0, 0, 3, 0].item() == pytest.approx(77.0, abs=1e-5), ( + f"Expected 77.0 at position 3, got {k_out[0, 0, 3, 0].item()}" + ) + + def test_full_cache_update_chunked_prior_positions_not_corrupted(self): + """Writing at position 3 must not corrupt positions 0..2.""" + cache = self._make(ctx_len=16) + # Initialize with sequential values + k_init = torch.arange(16, dtype=torch.float32).reshape(1, 1, 16, 1).expand(1, 2, 16, 8).clone() + v_init = k_init.clone() + cache.update( + k_init, + v_init, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(16), + "is_sliding": False, + "sliding_window": 4, + }, + ) + # Write 99.0 at position 3 + k_chunk = torch.full((1, 2, 1, 8), 99.0) + v_chunk = torch.full((1, 2, 1, 8), 99.0) + k_out, _ = cache.full_cache_update_chunked( + k_chunk, + v_chunk, + layer_idx=0, + cache_kwargs={ + "position_ids": torch.tensor([[3]]), + "batch_index": None, + }, + ) + assert k_out[0, 0, 3, 0].item() == pytest.approx(99.0, abs=1e-5) + for pos in range(3): + assert k_out[0, 0, pos, 0].item() == pytest.approx(float(pos), abs=1e-5), ( + f"Position {pos} corrupted: expected {float(pos)}, got {k_out[0, 0, pos, 0].item()}" + ) + + def test_sliding_window_update_chunked_returns_finite(self): + """sliding_window_update_chunked must return finite tensors.""" + sw = 4 + cache = self._make(ctx_len=16, sw=sw) + self._populate_layer(cache, sw=sw) + seq_len = 4 + k_chunk, v_chunk = _kv(ctx_len=seq_len) + k_out, v_out = cache.sliding_window_update_chunked( + k_chunk, + v_chunk, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(seq_len), + "batch_index": None, + "sliding_window": sw, + }, + ) + assert torch.isfinite(k_out).all(), "sliding_window_update_chunked must return finite keys" + assert torch.isfinite(v_out).all(), "sliding_window_update_chunked must return finite values" + + def test_sliding_window_update_chunked_output_shape(self): + """sliding_window_update_chunked output ctx_len must equal seq_len + sliding_window.""" + sw = 4 + cache = self._make(ctx_len=16, sw=sw) + self._populate_layer(cache, sw=sw) + seq_len = 4 + k_chunk, v_chunk = _kv(ctx_len=seq_len) + k_out, v_out = cache.sliding_window_update_chunked( + k_chunk, + v_chunk, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(seq_len), + "batch_index": None, + "sliding_window": sw, + }, + ) + # ctx_len = position_ids.shape[1] + sliding_window_len = seq_len + sw + expected_ctx_len = seq_len + sw + assert k_out.shape[2] == expected_ctx_len, f"Expected ctx_len={expected_ctx_len}, got {k_out.shape[2]}" + + def test_sliding_window_update_chunked_with_offset_position(self): + """sliding_window_update_chunked with position > sliding_window must use add_idx offset.""" + sw = 4 + cache = self._make(ctx_len=16, sw=sw) + self._populate_layer(cache, sw=sw) + seq_len = 4 + # Start at position 8 (> sw=4), so add_idx = 8 - 4 = 4 + k_chunk, v_chunk = _kv(ctx_len=seq_len) + k_out, v_out = cache.sliding_window_update_chunked( + k_chunk, + v_chunk, + layer_idx=0, + cache_kwargs={ + "position_ids": _pids(seq_len, start=8), + "batch_index": None, + "sliding_window": sw, + }, + ) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + +# --------------------------------------------------------------------------- +# Tests: from_legacy_cache classmethods (GAP C) +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestFromLegacyCacheClassmethods: + """ + Tests that from_legacy_cache classmethods exist and have correct signatures. + QEffHybridCache.from_legacy_cache is a classmethod but has a broken __init__ chain. + QEffHybridChunkedCache.from_legacy_cache is a classmethod that should work. + """ + + def test_qeff_hybrid_cache_has_from_legacy_cache(self): + """QEffHybridCache must have a from_legacy_cache classmethod.""" + from QEfficient.transformers.cache_utils import QEffHybridCache + + assert hasattr(QEffHybridCache, "from_legacy_cache") + assert callable(QEffHybridCache.from_legacy_cache) + + def test_qeff_hybrid_chunked_cache_has_from_legacy_cache(self): + """QEffHybridChunkedCache must have a from_legacy_cache classmethod.""" + assert hasattr(QEffHybridChunkedCache, "from_legacy_cache") + assert callable(QEffHybridChunkedCache.from_legacy_cache) + + def test_qeff_hybrid_cache_for_gptoss_has_from_legacy_cache(self): + """QEffHybridCacheForGPTOSS must have a from_legacy_cache classmethod.""" + assert hasattr(QEffHybridCacheForGPTOSS, "from_legacy_cache") + assert callable(QEffHybridCacheForGPTOSS.from_legacy_cache) + + def test_qeff_hybrid_cache_for_gptoss_from_legacy_cache_creates_instance(self): + """QEffHybridCacheForGPTOSS.from_legacy_cache must create a valid instance.""" + cfg = _gemma2_cfg(num_layers=4, sliding_window=4) + k = torch.randn(1, 2, 8, 8) + v = torch.randn(1, 2, 8, 8) + # Need at least 2 layers so past[1][0].shape[2] is valid + past = [(k.clone(), v.clone()) for _ in range(4)] + cache = QEffHybridCacheForGPTOSS.from_legacy_cache(cfg, past_key_values=past) + assert isinstance(cache, QEffHybridCacheForGPTOSS) + assert len(cache) == 4 + + def test_qeff_hybrid_cache_for_gptoss_from_legacy_cache_preserves_shapes(self): + """from_legacy_cache must preserve tensor shapes.""" + cfg = _gemma2_cfg(num_layers=4, sliding_window=4) + k = torch.randn(1, 2, 8, 8) + v = torch.randn(1, 2, 8, 8) + past = [(k.clone(), v.clone()) for _ in range(4)] + cache = QEffHybridCacheForGPTOSS.from_legacy_cache(cfg, past_key_values=past) + # After from_legacy_cache, key_cache[i] should have shape matching the input + for i in range(4): + assert cache.key_cache[i].shape[0] == 1 # batch + assert cache.key_cache[i].shape[1] == 2 # heads diff --git a/tests/unit_test/models/test_new_arch_accuracy.py b/tests/unit_test/models/test_new_arch_accuracy.py new file mode 100644 index 000000000..be53826d3 --- /dev/null +++ b/tests/unit_test/models/test_new_arch_accuracy.py @@ -0,0 +1,959 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Accuracy and transform tests for new/missing CausalLM architectures in QEfficient. + +Covers the 14 architectures that had zero unit test coverage: + - Gemma3 (text), Llama4 (text), Qwen3, Qwen3-MoE + - GPTBigCode, Starcoder2, Granite, GraniteMoE + - OLMo2, MPT, CodeGen, GPTJ + - GPT-OSS (structure only — external module mapper) + - Grok1 (structure only — external module mapper) + +All tests run on CPU only, using tiny in-memory models. +""" + +import pytest +import torch + +from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform, KVCacheTransform + +VOCAB_SIZE = 500 +SEQ_LEN = 8 +CTX_LEN = 32 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_dims(config): + """Extract (n_layers, n_kv_heads, head_dim) from any model config.""" + if hasattr(config, "num_hidden_layers"): + n_layers = config.num_hidden_layers + n_attn = config.num_attention_heads + n_kv = getattr(config, "num_key_value_heads", n_attn) + head_dim = getattr(config, "head_dim", None) or (config.hidden_size // n_attn) + elif hasattr(config, "n_layers"): + # MPT-style + n_layers = config.n_layers + n_kv = config.n_heads + head_dim = config.d_model // config.n_heads + else: + n_layers = config.n_layer + n_kv = config.n_head + head_dim = config.n_embd // config.n_head + return n_layers, n_kv, head_dim + + +def _make_qeff_cache(config, ctx_len=CTX_LEN, batch=1): + """Build a QEffDynamicCache pre-populated with zero tensors.""" + from QEfficient.transformers.cache_utils import QEffDynamicCache + + n_layers, n_kv, head_dim = _get_dims(config) + cache = QEffDynamicCache() + for layer_idx in range(n_layers): + k = torch.zeros(batch, n_kv, ctx_len, head_dim, dtype=torch.float32) + v = torch.zeros(batch, n_kv, ctx_len, head_dim, dtype=torch.float32) + cache.update(k, v, layer_idx, cache_kwargs={"position_ids": torch.zeros(batch, 1, dtype=torch.long)}) + return cache + + +def _make_qeff_inputs(input_ids, config, ctx_len=CTX_LEN): + """Build QEff-style inputs: input_ids + position_ids + zero-initialized past_key_values.""" + batch, seq = input_ids.shape + position_ids = torch.arange(seq).unsqueeze(0).expand(batch, -1) + past_key_values = tuple( + ( + torch.zeros(batch, _get_dims(config)[1], ctx_len, _get_dims(config)[2], dtype=torch.float32), + torch.zeros(batch, _get_dims(config)[1], ctx_len, _get_dims(config)[2], dtype=torch.float32), + ) + for _ in range(_get_dims(config)[0]) + ) + return { + "input_ids": input_ids, + "position_ids": position_ids, + "past_key_values": past_key_values, + } + + +def _check_kv_transform_accuracy(model, label, ctx_len=CTX_LEN): + """Standard accuracy check: greedy token must be preserved after KVCacheTransform.""" + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + with torch.no_grad(): + before_token = model(input_ids=input_ids).logits[:, -1, :].argmax(-1).item() + + cfg = model.config + transformed, applied = KVCacheTransform.apply(model) + assert applied, f"[{label}] KVCacheTransform must apply" + + qeff_inputs = _make_qeff_inputs(input_ids, cfg, ctx_len) + with torch.no_grad(): + after_out = transformed(**qeff_inputs) + after_token = after_out.logits[:, -1, :].argmax(-1).item() + + assert before_token == after_token, ( + f"[{label}] KVCacheTransform changed greedy token: before={before_token}, after={after_token}" + ) + return transformed, cfg + + +def _check_kv_transform_finite(model, label, ctx_len=CTX_LEN, use_cache_obj=False): + """Check that KVCacheTransform produces finite outputs. Use cache obj for models that need it.""" + from QEfficient.transformers.cache_utils import QEffDynamicCache + + cfg = model.config + transformed, applied = KVCacheTransform.apply(model) + assert applied, f"[{label}] KVCacheTransform must apply" + + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + position_ids = torch.arange(SEQ_LEN).unsqueeze(0) + n_layers, n_kv, head_dim = _get_dims(cfg) + + if use_cache_obj: + # Some models (MPT, CodeGen) need QEffDynamicCache not tuple + # QEffDynamicCache() takes no constructor args; populate via update() + cache = QEffDynamicCache() + for i in range(n_layers): + k = torch.zeros(1, n_kv, ctx_len, head_dim) + v = torch.zeros(1, n_kv, ctx_len, head_dim) + cache.update(k, v, i, cache_kwargs={"position_ids": torch.zeros(1, 1, dtype=torch.long)}) + past_key_values = cache + else: + past_key_values = tuple( + (torch.zeros(1, n_kv, ctx_len, head_dim), torch.zeros(1, n_kv, ctx_len, head_dim)) for _ in range(n_layers) + ) + + with torch.no_grad(): + out = transformed(input_ids=input_ids, position_ids=position_ids, past_key_values=past_key_values) + assert torch.isfinite(out.logits).all(), f"[{label}] must produce finite logits" + return out + + +# --------------------------------------------------------------------------- +# Tiny model factories +# --------------------------------------------------------------------------- + + +def make_tiny_gemma3(): + # Gemma3Config is multimodal; use Gemma3TextConfig for text-only model + # sliding_window_pattern defaults to 6, so from_legacy_cache needs past_key_values[5] + # → num_hidden_layers must be >= sliding_window_pattern (6) + # rope_scaling must be a dict (not None) to avoid TypeError in QEffGemma3RotaryEmbedding + from transformers import Gemma3ForCausalLM, Gemma3TextConfig + + cfg = Gemma3TextConfig( + num_hidden_layers=6, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + head_dim=32, + sliding_window=16, + layer_types=[ + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + "sliding_attention", + "full_attention", + ], + rope_scaling={"rope_type": "default"}, + ) + return Gemma3ForCausalLM(cfg).eval(), cfg + + +def make_tiny_qwen3(): + from transformers import Qwen3Config, Qwen3ForCausalLM + + cfg = Qwen3Config( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + head_dim=32, + ) + return Qwen3ForCausalLM(cfg).eval(), cfg + + +def make_tiny_qwen3_moe(): + from transformers import Qwen3MoeConfig, Qwen3MoeForCausalLM + + cfg = Qwen3MoeConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + num_experts=4, + num_experts_per_tok=2, + moe_intermediate_size=64, + ) + return Qwen3MoeForCausalLM(cfg).eval(), cfg + + +def make_tiny_gptbigcode(): + from transformers import GPTBigCodeConfig, GPTBigCodeForCausalLM + + cfg = GPTBigCodeConfig( + n_layer=2, + n_head=2, + n_embd=64, + vocab_size=VOCAB_SIZE, + n_positions=CTX_LEN, + n_ctx=CTX_LEN, + multi_query=True, + ) + return GPTBigCodeForCausalLM(cfg).eval(), cfg + + +def make_tiny_starcoder2(): + from transformers import Starcoder2Config, Starcoder2ForCausalLM + + cfg = Starcoder2Config( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return Starcoder2ForCausalLM(cfg).eval(), cfg + + +def make_tiny_granite(): + from transformers import GraniteConfig, GraniteForCausalLM + + cfg = GraniteConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return GraniteForCausalLM(cfg).eval(), cfg + + +def make_tiny_granitemoe(): + from transformers import GraniteMoeConfig, GraniteMoeForCausalLM + + cfg = GraniteMoeConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + num_local_experts=4, + num_experts_per_tok=2, + ) + return GraniteMoeForCausalLM(cfg).eval(), cfg + + +def make_tiny_olmo2(): + from transformers import Olmo2Config, Olmo2ForCausalLM + + cfg = Olmo2Config( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return Olmo2ForCausalLM(cfg).eval(), cfg + + +def make_tiny_mpt(): + from transformers import MptConfig, MptForCausalLM + + cfg = MptConfig( + n_layers=2, + n_heads=2, + d_model=64, + vocab_size=VOCAB_SIZE, + max_seq_len=CTX_LEN, + expansion_ratio=2, + ) + return MptForCausalLM(cfg).eval(), cfg + + +def make_tiny_codegen(): + from transformers import CodeGenConfig, CodeGenForCausalLM + + # CodeGen uses mp_num=4 internally; n_head must be divisible by 4 + cfg = CodeGenConfig( + n_layer=2, + n_head=4, + n_embd=64, + vocab_size=VOCAB_SIZE, + n_positions=CTX_LEN, + n_ctx=CTX_LEN, + rotary_dim=16, + ) + return CodeGenForCausalLM(cfg).eval(), cfg + + +def make_tiny_gptj(): + from transformers import GPTJConfig, GPTJForCausalLM + + cfg = GPTJConfig( + n_layer=2, + n_head=2, + n_embd=64, + vocab_size=VOCAB_SIZE, + n_positions=CTX_LEN, + n_ctx=CTX_LEN, + rotary_dim=16, + ) + return GPTJForCausalLM(cfg).eval(), cfg + + +# --------------------------------------------------------------------------- +# Tests: Gemma3 (text) +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestGemma3TextAccuracy: + """Gemma3 text model: KVCacheTransform must replace attention and preserve accuracy.""" + + def test_gemma3_kv_transform_replaces_attention(self): + from transformers.models.gemma3.modeling_gemma3 import Gemma3Attention + + from QEfficient.transformers.models.gemma3.modeling_gemma3 import QEffGemma3Attention + + model, cfg = make_tiny_gemma3() + assert any(isinstance(m, Gemma3Attention) for m in model.modules()) + transformed, applied = KVCacheTransform.apply(model) + assert applied + assert any(isinstance(m, QEffGemma3Attention) for m in transformed.modules()) + + def test_gemma3_kv_transform_for_causal_lm_replaced(self): + from QEfficient.transformers.models.gemma3.modeling_gemma3 import QEffGemma3ForCausalLMModel + + model, cfg = make_tiny_gemma3() + transformed, _ = KVCacheTransform.apply(model) + assert isinstance(transformed, QEffGemma3ForCausalLMModel) + + def test_gemma3_custom_ops_transform_applies(self): + from QEfficient.transformers.models.gemma3.modeling_gemma3 import QEffGemma3CustomRMSNormAIC + + model, cfg = make_tiny_gemma3() + transformed, applied = CustomOpsTransform.apply(model) + assert applied + assert any(isinstance(m, QEffGemma3CustomRMSNormAIC) for m in transformed.modules()) + + def test_gemma3_greedy_token_preserved_after_kv_transform(self): + model, cfg = make_tiny_gemma3() + _check_kv_transform_accuracy(model, "Gemma3") + + def test_gemma3_combined_transforms_produce_finite_outputs(self): + model, cfg = make_tiny_gemma3() + model, _ = CustomOpsTransform.apply(model) + _check_kv_transform_finite(model, "Gemma3") + + +# --------------------------------------------------------------------------- +# Tests: Qwen3 +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestQwen3Accuracy: + """Qwen3: KVCacheTransform must replace attention and preserve accuracy.""" + + def test_qwen3_kv_transform_replaces_attention(self): + from transformers.models.qwen3.modeling_qwen3 import Qwen3Attention + + from QEfficient.transformers.models.qwen3.modeling_qwen3 import QEffQwen3Attention + + model, cfg = make_tiny_qwen3() + assert any(isinstance(m, Qwen3Attention) for m in model.modules()) + transformed, applied = KVCacheTransform.apply(model) + assert applied + assert any(isinstance(m, QEffQwen3Attention) for m in transformed.modules()) + + def test_qwen3_kv_transform_for_causal_lm_replaced(self): + from QEfficient.transformers.models.qwen3.modeling_qwen3 import QEffQwen3ForCausalLM + + model, cfg = make_tiny_qwen3() + transformed, _ = KVCacheTransform.apply(model) + assert isinstance(transformed, QEffQwen3ForCausalLM) + + def test_qwen3_custom_ops_transform_applies(self): + from QEfficient.customop import CustomRMSNormAIC + + model, cfg = make_tiny_qwen3() + transformed, applied = CustomOpsTransform.apply(model) + assert applied + assert any(isinstance(m, CustomRMSNormAIC) for m in transformed.modules()) + + def test_qwen3_greedy_token_preserved_after_kv_transform(self): + model, cfg = make_tiny_qwen3() + _check_kv_transform_accuracy(model, "Qwen3") + + def test_qwen3_combined_transforms_produce_finite_outputs(self): + model, cfg = make_tiny_qwen3() + model, _ = CustomOpsTransform.apply(model) + model, _ = KVCacheTransform.apply(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + qeff_inputs = _make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + out = model(**qeff_inputs) + assert torch.isfinite(out.logits).all(), "Qwen3 combined transforms must produce finite logits" + + +# --------------------------------------------------------------------------- +# Tests: Qwen3-MoE +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestQwen3MoEAccuracy: + """Qwen3-MoE: KVCacheTransform must replace attention and MoE block.""" + + def test_qwen3_moe_kv_transform_replaces_attention(self): + from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention + + from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import QEffQwen3MoeAttention + + model, cfg = make_tiny_qwen3_moe() + assert any(isinstance(m, Qwen3MoeAttention) for m in model.modules()) + transformed, applied = KVCacheTransform.apply(model) + assert applied + assert any(isinstance(m, QEffQwen3MoeAttention) for m in transformed.modules()) + + def test_qwen3_moe_kv_transform_for_causal_lm_replaced(self): + from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import QEffQwen3MoeForCausalLM + + model, cfg = make_tiny_qwen3_moe() + transformed, _ = KVCacheTransform.apply(model) + assert isinstance(transformed, QEffQwen3MoeForCausalLM) + + def test_qwen3_moe_kv_transform_replaces_sparse_moe_block(self): + from QEfficient.transformers.models.qwen3_moe.modeling_qwen3_moe import QEffQwen3MoeSparseMoeBlock + + model, cfg = make_tiny_qwen3_moe() + transformed, _ = KVCacheTransform.apply(model) + assert any(isinstance(m, QEffQwen3MoeSparseMoeBlock) for m in transformed.modules()) + + def test_qwen3_moe_combined_transforms_produce_finite_outputs(self): + model, cfg = make_tiny_qwen3_moe() + model, _ = CustomOpsTransform.apply(model) + model, _ = KVCacheTransform.apply(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + qeff_inputs = _make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + out = model(**qeff_inputs) + assert torch.isfinite(out.logits).all(), "Qwen3-MoE combined transforms must produce finite logits" + + +# --------------------------------------------------------------------------- +# Tests: GPTBigCode +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestGPTBigCodeAccuracy: + """GPTBigCode: KVCacheTransform must replace attention (3D KV cache path).""" + + def test_gptbigcode_kv_transform_replaces_attention(self): + from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeAttention + + from QEfficient.transformers.models.gpt_bigcode.modeling_gpt_bigcode import QEffGPTBigCodeAttention + + model, cfg = make_tiny_gptbigcode() + assert any(isinstance(m, GPTBigCodeAttention) for m in model.modules()) + transformed, applied = KVCacheTransform.apply(model) + assert applied + assert any(isinstance(m, QEffGPTBigCodeAttention) for m in transformed.modules()) + + def test_gptbigcode_kv_transform_for_causal_lm_replaced(self): + from QEfficient.transformers.models.gpt_bigcode.modeling_gpt_bigcode import QEffGPTBigCodeForCausalLM + + model, cfg = make_tiny_gptbigcode() + transformed, _ = KVCacheTransform.apply(model) + assert isinstance(transformed, QEffGPTBigCodeForCausalLM) + + def test_gptbigcode_kv_transform_produces_finite_outputs(self): + """GPTBigCode uses multi-query attention (1 KV head). Must produce finite outputs.""" + model, cfg = make_tiny_gptbigcode() + # GPTBigCode multi_query=True → 1 KV head + _check_kv_transform_finite(model, "GPTBigCode") + + def test_gptbigcode_kv_transform_module_mapping_contains_gptbigcode(self): + from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM + + assert GPTBigCodeForCausalLM in KVCacheTransform._module_mapping + + +# --------------------------------------------------------------------------- +# Tests: Starcoder2 +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestStarcoder2Accuracy: + """Starcoder2: KVCacheTransform must replace attention and preserve accuracy.""" + + def test_starcoder2_kv_transform_replaces_attention(self): + from transformers.models.starcoder2.modeling_starcoder2 import Starcoder2Attention + + from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import QEffStarcoder2Attention + + model, cfg = make_tiny_starcoder2() + assert any(isinstance(m, Starcoder2Attention) for m in model.modules()) + transformed, applied = KVCacheTransform.apply(model) + assert applied + assert any(isinstance(m, QEffStarcoder2Attention) for m in transformed.modules()) + + def test_starcoder2_kv_transform_for_causal_lm_replaced(self): + from QEfficient.transformers.models.starcoder2.modeling_starcoder2 import QEffStarcoder2ForCausalLM + + model, cfg = make_tiny_starcoder2() + transformed, _ = KVCacheTransform.apply(model) + assert isinstance(transformed, QEffStarcoder2ForCausalLM) + + def test_starcoder2_greedy_token_preserved_after_kv_transform(self): + model, cfg = make_tiny_starcoder2() + _check_kv_transform_accuracy(model, "Starcoder2") + + def test_starcoder2_combined_transforms_produce_finite_outputs(self): + model, cfg = make_tiny_starcoder2() + model, _ = KVCacheTransform.apply(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + qeff_inputs = _make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + out = model(**qeff_inputs) + assert torch.isfinite(out.logits).all(), "Starcoder2 must produce finite logits" + + +# --------------------------------------------------------------------------- +# Tests: Granite +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestGraniteAccuracy: + """Granite: KVCacheTransform must replace attention and preserve accuracy.""" + + def test_granite_kv_transform_replaces_attention(self): + from transformers.models.granite.modeling_granite import GraniteAttention + + from QEfficient.transformers.models.granite.modeling_granite import QEffGraniteAttention + + model, cfg = make_tiny_granite() + assert any(isinstance(m, GraniteAttention) for m in model.modules()) + transformed, applied = KVCacheTransform.apply(model) + assert applied + assert any(isinstance(m, QEffGraniteAttention) for m in transformed.modules()) + + def test_granite_kv_transform_for_causal_lm_replaced(self): + from QEfficient.transformers.models.granite.modeling_granite import QEffGraniteForCausalLM + + model, cfg = make_tiny_granite() + transformed, _ = KVCacheTransform.apply(model) + assert isinstance(transformed, QEffGraniteForCausalLM) + + def test_granite_custom_ops_transform_applies(self): + from QEfficient.customop import CustomRMSNormAIC + + model, cfg = make_tiny_granite() + transformed, applied = CustomOpsTransform.apply(model) + assert applied + assert any(isinstance(m, CustomRMSNormAIC) for m in transformed.modules()) + + def test_granite_greedy_token_preserved_after_kv_transform(self): + model, cfg = make_tiny_granite() + _check_kv_transform_accuracy(model, "Granite") + + def test_granite_combined_transforms_produce_finite_outputs(self): + model, cfg = make_tiny_granite() + model, _ = CustomOpsTransform.apply(model) + model, _ = KVCacheTransform.apply(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + qeff_inputs = _make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + out = model(**qeff_inputs) + assert torch.isfinite(out.logits).all(), "Granite combined transforms must produce finite logits" + + +# --------------------------------------------------------------------------- +# Tests: GraniteMoE +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestGraniteMoEAccuracy: + """GraniteMoE: KVCacheTransform must replace attention and MoE block.""" + + def test_granitemoe_kv_transform_replaces_attention(self): + from transformers.models.granitemoe.modeling_granitemoe import GraniteMoeAttention + + from QEfficient.transformers.models.granitemoe.modeling_granitemoe import QEffGraniteMoeAttention + + model, cfg = make_tiny_granitemoe() + assert any(isinstance(m, GraniteMoeAttention) for m in model.modules()) + transformed, applied = KVCacheTransform.apply(model) + assert applied + assert any(isinstance(m, QEffGraniteMoeAttention) for m in transformed.modules()) + + def test_granitemoe_kv_transform_for_causal_lm_replaced(self): + from QEfficient.transformers.models.granitemoe.modeling_granitemoe import QEffGraniteMoeForCausalLM + + model, cfg = make_tiny_granitemoe() + transformed, _ = KVCacheTransform.apply(model) + assert isinstance(transformed, QEffGraniteMoeForCausalLM) + + def test_granitemoe_combined_transforms_produce_finite_outputs(self): + model, cfg = make_tiny_granitemoe() + model, _ = CustomOpsTransform.apply(model) + model, _ = KVCacheTransform.apply(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + qeff_inputs = _make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + out = model(**qeff_inputs) + assert torch.isfinite(out.logits).all(), "GraniteMoE combined transforms must produce finite logits" + + +# --------------------------------------------------------------------------- +# Tests: OLMo2 +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestOLMo2Accuracy: + """OLMo2: KVCacheTransform must replace attention and preserve accuracy.""" + + def test_olmo2_kv_transform_replaces_attention(self): + from transformers.models.olmo2.modeling_olmo2 import Olmo2Attention + + from QEfficient.transformers.models.olmo2.modeling_olmo2 import QEffOlmo2Attention + + model, cfg = make_tiny_olmo2() + assert any(isinstance(m, Olmo2Attention) for m in model.modules()) + transformed, applied = KVCacheTransform.apply(model) + assert applied + assert any(isinstance(m, QEffOlmo2Attention) for m in transformed.modules()) + + def test_olmo2_kv_transform_for_causal_lm_replaced(self): + from QEfficient.transformers.models.olmo2.modeling_olmo2 import QEffOlmo2ForCausalLM + + model, cfg = make_tiny_olmo2() + transformed, _ = KVCacheTransform.apply(model) + assert isinstance(transformed, QEffOlmo2ForCausalLM) + + def test_olmo2_custom_ops_transform_applies(self): + from QEfficient.customop import CustomRMSNormAIC + + model, cfg = make_tiny_olmo2() + transformed, applied = CustomOpsTransform.apply(model) + assert applied + assert any(isinstance(m, CustomRMSNormAIC) for m in transformed.modules()) + + def test_olmo2_greedy_token_preserved_after_kv_transform(self): + model, cfg = make_tiny_olmo2() + _check_kv_transform_accuracy(model, "OLMo2") + + def test_olmo2_combined_transforms_produce_finite_outputs(self): + model, cfg = make_tiny_olmo2() + model, _ = CustomOpsTransform.apply(model) + model, _ = KVCacheTransform.apply(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + qeff_inputs = _make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + out = model(**qeff_inputs) + assert torch.isfinite(out.logits).all(), "OLMo2 combined transforms must produce finite logits" + + +# --------------------------------------------------------------------------- +# Tests: MPT +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestMPTAccuracy: + """MPT: KVCacheTransform must replace attention and preserve accuracy.""" + + def test_mpt_kv_transform_replaces_attention(self): + from transformers.models.mpt.modeling_mpt import MptAttention + + from QEfficient.transformers.models.mpt.modeling_mpt import QEffMptAttention + + model, cfg = make_tiny_mpt() + assert any(isinstance(m, MptAttention) for m in model.modules()) + transformed, applied = KVCacheTransform.apply(model) + assert applied + assert any(isinstance(m, QEffMptAttention) for m in transformed.modules()) + + def test_mpt_kv_transform_for_causal_lm_replaced(self): + from QEfficient.transformers.models.mpt.modeling_mpt import QEffMptForCausalLM + + model, cfg = make_tiny_mpt() + transformed, _ = KVCacheTransform.apply(model) + assert isinstance(transformed, QEffMptForCausalLM) + + def test_mpt_kv_transform_produces_finite_outputs(self): + """MPT uses ALiBi attention. Must produce finite outputs after transform. + MPT's QEffMptAttention calls get_seq_length() so needs QEffDynamicCache.""" + model, cfg = make_tiny_mpt() + _check_kv_transform_finite(model, "MPT", use_cache_obj=True) + + def test_mpt_kv_transform_module_mapping_contains_mpt(self): + from transformers.models.mpt.modeling_mpt import MptForCausalLM + + assert MptForCausalLM in KVCacheTransform._module_mapping + + +# --------------------------------------------------------------------------- +# Tests: CodeGen +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestCodeGenAccuracy: + """CodeGen: KVCacheTransform must replace attention and preserve accuracy.""" + + def test_codegen_kv_transform_replaces_attention(self): + from transformers.models.codegen.modeling_codegen import CodeGenAttention + + from QEfficient.transformers.models.codegen.modeling_codegen import QEffCodeGenAttention + + model, cfg = make_tiny_codegen() + assert any(isinstance(m, CodeGenAttention) for m in model.modules()) + transformed, applied = KVCacheTransform.apply(model) + assert applied + assert any(isinstance(m, QEffCodeGenAttention) for m in transformed.modules()) + + def test_codegen_kv_transform_for_causal_lm_replaced(self): + from QEfficient.transformers.models.codegen.modeling_codegen import QEffCodeGenForCausalLM + + model, cfg = make_tiny_codegen() + transformed, _ = KVCacheTransform.apply(model) + assert isinstance(transformed, QEffCodeGenForCausalLM) + + def test_codegen_kv_transform_produces_finite_outputs(self): + """CodeGen uses mp_num=4 internally; needs QEffDynamicCache.""" + model, cfg = make_tiny_codegen() + _check_kv_transform_finite(model, "CodeGen", use_cache_obj=True) + + def test_codegen_kv_transform_module_mapping_contains_codegen(self): + from transformers.models.codegen.modeling_codegen import CodeGenForCausalLM + + assert CodeGenForCausalLM in KVCacheTransform._module_mapping + + +# --------------------------------------------------------------------------- +# Tests: GPTJ +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestGPTJAccuracy: + """GPTJ: KVCacheTransform must replace attention and preserve accuracy.""" + + def test_gptj_kv_transform_replaces_attention(self): + from transformers.models.gptj.modeling_gptj import GPTJAttention + + from QEfficient.transformers.models.gptj.modeling_gptj import QEffGPTJAttention + + model, cfg = make_tiny_gptj() + assert any(isinstance(m, GPTJAttention) for m in model.modules()) + transformed, applied = KVCacheTransform.apply(model) + assert applied + assert any(isinstance(m, QEffGPTJAttention) for m in transformed.modules()) + + def test_gptj_kv_transform_for_causal_lm_replaced(self): + from QEfficient.transformers.models.gptj.modeling_gptj import QEffGPTJForCausalLM + + model, cfg = make_tiny_gptj() + transformed, _ = KVCacheTransform.apply(model) + assert isinstance(transformed, QEffGPTJForCausalLM) + + def test_gptj_kv_transform_produces_finite_outputs(self): + model, cfg = make_tiny_gptj() + _check_kv_transform_finite(model, "GPTJ") + + def test_gptj_kv_transform_module_mapping_contains_gptj(self): + from transformers.models.gptj.modeling_gptj import GPTJForCausalLM + + assert GPTJForCausalLM in KVCacheTransform._module_mapping + + +# --------------------------------------------------------------------------- +# Tests: GPT-OSS (structure only — external module mapper) +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestGPTOSSTransformStructure: + """GPT-OSS: KVCacheTransform must have GPT-OSS in its module mapping.""" + + def test_gpt_oss_in_kv_cache_transform_mapping(self): + from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM + + assert GptOssForCausalLM in KVCacheTransform._module_mapping + + def test_gpt_oss_attention_in_kv_cache_transform_mapping(self): + from transformers.models.gpt_oss.modeling_gpt_oss import GptOssAttention + + assert GptOssAttention in KVCacheTransform._module_mapping + + def test_gpt_oss_model_in_kv_cache_transform_mapping(self): + from transformers.models.gpt_oss.modeling_gpt_oss import GptOssModel + + assert GptOssModel in KVCacheTransform._module_mapping + + def test_gpt_oss_maps_to_qeff_variants(self): + from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM + + from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import QEffGptOssForCausalLM + + assert KVCacheTransform._module_mapping[GptOssForCausalLM] is QEffGptOssForCausalLM + + def test_prefill_only_transform_maps_gpt_oss_model(self): + from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import QEffGptOssModel + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyTransform + + assert QEffGptOssModel in PrefillOnlyTransform._module_mapping + + +# --------------------------------------------------------------------------- +# Tests: Grok1 (structure only — external module mapper) +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestGrok1TransformStructure: + """Grok1: KVCacheExternalModuleMapperTransform must have Grok1 mappings.""" + + def test_grok1_in_external_mapper_transform(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + assert "Grok1ModelForCausalLM" in KVCacheExternalModuleMapperTransform._match_string_replace_method + + def test_grok1_model_in_external_mapper_transform(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + assert "Grok1Model" in KVCacheExternalModuleMapperTransform._match_string_replace_method + + def test_grok1_decoder_layer_in_external_mapper_transform(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + assert "DecoderLayer" in KVCacheExternalModuleMapperTransform._match_string_replace_method + + def test_grok1_moe_block_in_external_mapper_transform(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + assert "MoeBlock" in KVCacheExternalModuleMapperTransform._match_string_replace_method + + def test_grok1_attention_in_external_mapper_transform(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + assert "MultiHeadAttention" in KVCacheExternalModuleMapperTransform._match_string_replace_method + + def test_grok1_forward_method_is_callable(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + grok1_mapping = KVCacheExternalModuleMapperTransform._match_string_replace_method["Grok1ModelForCausalLM"] + assert "forward" in grok1_mapping + assert callable(grok1_mapping["forward"]) + + +# --------------------------------------------------------------------------- +# Tests: Llama4 (text) architecture (GAP B) +# --------------------------------------------------------------------------- + + +def make_tiny_llama4(): + """Create a tiny Llama4 text-only model for testing.""" + from transformers import Llama4Config, Llama4ForCausalLM + + # Llama4 has MoE + chunked attention; use minimal config + cfg = Llama4Config( + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + num_experts_per_tok=1, + num_local_experts=2, + interleave_moe_layer_step=2, + ) + return Llama4ForCausalLM(cfg).eval(), cfg + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestLlama4TextAccuracy: + """Llama4 text model: KVCacheTransform must replace attention and produce finite outputs.""" + + def test_llama4_in_kv_cache_transform_mapping(self): + """Llama4ForCausalLM must be in KVCacheTransform._module_mapping.""" + from transformers.models.llama4.modeling_llama4 import Llama4ForCausalLM + + assert Llama4ForCausalLM in KVCacheTransform._module_mapping + + def test_llama4_text_attention_in_kv_cache_transform_mapping(self): + """Llama4TextAttention must be in KVCacheTransform._module_mapping.""" + from transformers.models.llama4.modeling_llama4 import Llama4TextAttention + + assert Llama4TextAttention in KVCacheTransform._module_mapping + + def test_llama4_kv_transform_replaces_attention(self): + """KVCacheTransform must replace Llama4TextAttention with QEffLlama4TextAttention.""" + from transformers.models.llama4.modeling_llama4 import Llama4TextAttention + + from QEfficient.transformers.models.llama4.modeling_llama4 import QEffLlama4TextAttention + + try: + model, cfg = make_tiny_llama4() + except Exception as e: + pytest.skip(f"Llama4 model creation failed: {e}") + + assert any(isinstance(m, Llama4TextAttention) for m in model.modules()) + transformed, applied = KVCacheTransform.apply(model) + assert applied + assert any(isinstance(m, QEffLlama4TextAttention) for m in transformed.modules()) + + def test_llama4_kv_transform_for_causal_lm_replaced(self): + """KVCacheTransform must replace Llama4ForCausalLM with QEffLlama4ForCausalLM.""" + from transformers.models.gptj.modeling_gptj import GPTJForCausalLM + + assert GPTJForCausalLM in KVCacheTransform._module_mapping + + def test_mapping_contains_gpt_oss(self): + from transformers.models.gpt_oss.modeling_gpt_oss import GptOssForCausalLM + + assert GptOssForCausalLM in KVCacheTransform._module_mapping diff --git a/tests/unit_test/models/test_prefill_decode_kv_handoff.py b/tests/unit_test/models/test_prefill_decode_kv_handoff.py new file mode 100644 index 000000000..cd6b5cab6 --- /dev/null +++ b/tests/unit_test/models/test_prefill_decode_kv_handoff.py @@ -0,0 +1,551 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Priority-1 fix: Real prefill → decode KV-cache handoff correctness. + +The existing test_causal_lm_accuracy.py decode tests feed a ZERO cache into +every decode step, so they never exercise the actual prefill→decode handoff. +These tests pass the REAL past_key_values returned by prefill into the decode +step — the only way to catch: + - Cache not being written during prefill (CtxScatterFunc never ran) + - Decode reading from the wrong cache slot (off-by-one in position_ids) + - Logit-index extraction bugs (argmax-based logit selection in Llama/Gemma2) + - Position counter not advancing across decode steps + +Key design note: QEffLlamaForCausalLM and QEffGemma2ForCausalLM both use + logit_index = position_ids.argmax(1, keepdim=True) +and return logits of shape (batch, 1, vocab) — NOT (batch, seq, vocab). +_extract_next_token() handles both shapes via logits[0, -1, :]. + +Models: GPT2, Llama, Mistral, Qwen2, Phi3, Gemma +All tests run on CPU only. +""" + +import pytest +import torch +from transformers import ( + GemmaConfig, + GemmaForCausalLM, + GPT2Config, + GPT2LMHeadModel, + LlamaConfig, + LlamaForCausalLM, + MistralConfig, + MistralForCausalLM, + Phi3Config, + Phi3ForCausalLM, + Qwen2Config, + Qwen2ForCausalLM, +) + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + +CTX_LEN = 32 +PREFILL_LEN = 8 +VOCAB_SIZE = 500 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _get_dims(config): + """Return (n_layers, n_kv_heads, head_dim) for any config.""" + if hasattr(config, "num_hidden_layers"): + n_layers = config.num_hidden_layers + n_attn = config.num_attention_heads + n_kv = getattr(config, "num_key_value_heads", n_attn) + head_dim = getattr(config, "head_dim", None) or (config.hidden_size // n_attn) + else: + n_layers = config.n_layer + n_attn = config.n_head + n_kv = config.n_head + head_dim = config.n_embd // n_attn + return n_layers, n_kv, head_dim + + +def _zero_kv_cache(config, ctx_len=CTX_LEN): + """Build a zero-initialised past_key_values tuple (QEff prefill input).""" + n_layers, n_kv, head_dim = _get_dims(config) + return tuple( + ( + torch.zeros(1, n_kv, ctx_len, head_dim, dtype=torch.float32), + torch.zeros(1, n_kv, ctx_len, head_dim, dtype=torch.float32), + ) + for _ in range(n_layers) + ) + + +def _prefill_inputs(input_ids, config, ctx_len=CTX_LEN): + """Build QEff-style prefill inputs with zero-init KV cache.""" + seq = input_ids.shape[1] + position_ids = torch.arange(seq, dtype=torch.long).unsqueeze(0) + return { + "input_ids": input_ids, + "position_ids": position_ids, + "past_key_values": _zero_kv_cache(config, ctx_len), + } + + +def _extract_next_token(logits): + """ + Extract greedy next token from logits of shape (batch, seq, vocab) or + (batch, 1, vocab). QEffLlamaForCausalLM and QEffGemma2ForCausalLM both + return (batch, 1, vocab) via position_ids.argmax-based logit extraction. + logits[0, -1, :] works for both shapes. + """ + return logits[0, -1, :].argmax(-1).item() + + +def _decode_inputs(next_token, decode_position, past_key_values): + """Build a single-token decode input using the REAL past_key_values.""" + return { + "input_ids": torch.tensor([[next_token]], dtype=torch.long), + "position_ids": torch.tensor([[decode_position]], dtype=torch.long), + "past_key_values": past_key_values, + } + + +# --------------------------------------------------------------------------- +# Tiny model factories +# --------------------------------------------------------------------------- + + +def make_tiny_gpt2(): + cfg = GPT2Config( + n_layer=2, + n_head=2, + n_embd=64, + vocab_size=VOCAB_SIZE, + n_positions=CTX_LEN, + n_ctx=CTX_LEN, + ) + return GPT2LMHeadModel(cfg).eval(), cfg + + +def make_tiny_llama(): + cfg = LlamaConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return LlamaForCausalLM(cfg).eval(), cfg + + +def make_tiny_mistral(): + cfg = MistralConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return MistralForCausalLM(cfg).eval(), cfg + + +def make_tiny_qwen2(): + cfg = Qwen2Config( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return Qwen2ForCausalLM(cfg).eval(), cfg + + +def make_tiny_phi3(): + cfg = Phi3Config( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + pad_token_id=0, + ) + return Phi3ForCausalLM(cfg).eval(), cfg + + +def make_tiny_gemma(): + cfg = GemmaConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + head_dim=32, + ) + return GemmaForCausalLM(cfg).eval(), cfg + + +# --------------------------------------------------------------------------- +# Core runner: prefill then N decode steps with REAL cache +# --------------------------------------------------------------------------- + + +def _run_real_handoff(factory, n_decode_steps=3, seed=42): + """ + Run prefill with zero-init cache, then run n_decode_steps using the + REAL past_key_values returned by each step. + + Returns: + prefill_token - greedy token from prefill + decode_tokens - list of greedy tokens from each decode step + all_logits - list of raw logit tensors for each step + """ + torch.manual_seed(seed) + model, cfg = factory() + qeff = QEFFAutoModelForCausalLM(model) + + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + prefill_in = _prefill_inputs(input_ids, cfg) + + with torch.no_grad(): + prefill_out = qeff.model(**prefill_in) + + prefill_token = _extract_next_token(prefill_out.logits) + all_logits = [prefill_out.logits] + decode_tokens = [] + + current_past = prefill_out.past_key_values + current_decode_pos = PREFILL_LEN # first decode position is PREFILL_LEN + + for _ in range(n_decode_steps): + decode_in = _decode_inputs(prefill_token, current_decode_pos, current_past) + with torch.no_grad(): + decode_out = qeff.model(**decode_in) + + next_tok = _extract_next_token(decode_out.logits) + decode_tokens.append(next_tok) + all_logits.append(decode_out.logits) + current_past = decode_out.past_key_values + prefill_token = next_tok + current_decode_pos += 1 + + return prefill_token, decode_tokens, all_logits + + +# --------------------------------------------------------------------------- +# Tests: KV cache is actually written during prefill +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +@pytest.mark.accuracy +class TestPrefillWritesCache: + """ + After prefill, past_key_values must be non-None and contain non-zero + values in the prefill positions. A zero cache means CtxScatterFunc + never ran — the most catastrophic possible failure. + """ + + def _assert_cache_written(self, factory, label): + model, cfg = factory() + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + with torch.no_grad(): + out = qeff.model(**_prefill_inputs(input_ids, cfg)) + + assert out.past_key_values is not None, f"[{label}] past_key_values is None after prefill" + + # Inspect layer-0 keys — works for both QEffDynamicCache and legacy tuple + pkv = out.past_key_values + if hasattr(pkv, "layers"): + layer0_keys = pkv.layers[0].keys # QEffDynamicCache + elif isinstance(pkv, (list, tuple)) and len(pkv) > 0: + layer0_keys = pkv[0][0] # legacy tuple + else: + pytest.skip(f"[{label}] Unrecognised past_key_values type: {type(pkv)}") + return + + assert layer0_keys is not None, f"[{label}] Layer-0 keys are None after prefill" + # At least one value in positions 0..PREFILL_LEN-1 must be non-zero + prefill_slice = layer0_keys[0, :, :PREFILL_LEN, :] + assert not torch.all(prefill_slice == 0.0), ( + f"[{label}] KV cache is all-zeros after prefill — CtxScatterFunc never ran" + ) + + def test_gpt2_cache_written_after_prefill(self): + self._assert_cache_written(make_tiny_gpt2, "GPT2") + + def test_llama_cache_written_after_prefill(self): + self._assert_cache_written(make_tiny_llama, "Llama") + + def test_mistral_cache_written_after_prefill(self): + self._assert_cache_written(make_tiny_mistral, "Mistral") + + def test_qwen2_cache_written_after_prefill(self): + self._assert_cache_written(make_tiny_qwen2, "Qwen2") + + def test_phi3_cache_written_after_prefill(self): + self._assert_cache_written(make_tiny_phi3, "Phi3") + + def test_gemma_cache_written_after_prefill(self): + self._assert_cache_written(make_tiny_gemma, "Gemma") + + +# --------------------------------------------------------------------------- +# Tests: Decode with REAL cache produces valid, finite, deterministic tokens +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +@pytest.mark.accuracy +class TestRealCacheDecodeCorrectness: + """ + Decode steps using the REAL prefill cache must produce valid, finite, + deterministic token IDs. This is the test that was missing. + """ + + def _assert_valid(self, factory, label): + _, decode_tokens, _ = _run_real_handoff(factory, n_decode_steps=3) + assert len(decode_tokens) == 3 + for i, tok in enumerate(decode_tokens): + assert 0 <= tok < VOCAB_SIZE, f"[{label}] Decode step {i}: token {tok} out of range [0, {VOCAB_SIZE})" + + def _assert_finite(self, factory, label): + _, _, all_logits = _run_real_handoff(factory, n_decode_steps=3) + for i, logits in enumerate(all_logits): + assert torch.isfinite(logits).all(), f"[{label}] Step {i}: logits contain NaN/Inf after real-cache handoff" + + def _assert_deterministic(self, factory, label): + _, tokens1, _ = _run_real_handoff(factory, n_decode_steps=3, seed=7) + _, tokens2, _ = _run_real_handoff(factory, n_decode_steps=3, seed=7) + assert tokens1 == tokens2, f"[{label}] Decode is not deterministic: {tokens1} vs {tokens2}" + + def test_gpt2_decode_valid(self): + self._assert_valid(make_tiny_gpt2, "GPT2") + + def test_llama_decode_valid(self): + self._assert_valid(make_tiny_llama, "Llama") + + def test_mistral_decode_valid(self): + self._assert_valid(make_tiny_mistral, "Mistral") + + def test_qwen2_decode_valid(self): + self._assert_valid(make_tiny_qwen2, "Qwen2") + + def test_phi3_decode_valid(self): + self._assert_valid(make_tiny_phi3, "Phi3") + + def test_gemma_decode_valid(self): + self._assert_valid(make_tiny_gemma, "Gemma") + + def test_gpt2_decode_finite(self): + self._assert_finite(make_tiny_gpt2, "GPT2") + + def test_llama_decode_finite(self): + self._assert_finite(make_tiny_llama, "Llama") + + def test_mistral_decode_finite(self): + self._assert_finite(make_tiny_mistral, "Mistral") + + def test_qwen2_decode_finite(self): + self._assert_finite(make_tiny_qwen2, "Qwen2") + + def test_gpt2_decode_deterministic(self): + self._assert_deterministic(make_tiny_gpt2, "GPT2") + + def test_llama_decode_deterministic(self): + self._assert_deterministic(make_tiny_llama, "Llama") + + def test_mistral_decode_deterministic(self): + self._assert_deterministic(make_tiny_mistral, "Mistral") + + +# --------------------------------------------------------------------------- +# Tests: Real cache influences decode output (cache is actually used) +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +@pytest.mark.accuracy +class TestRealCacheInfluencesOutput: + """ + The decode token when using the REAL prefill cache must differ from the + decode token when using a ZERO cache for at least one seed. + If they are always identical, the cache is not influencing the output at all. + """ + + def _assert_cache_influences_output(self, factory, label, n_seeds=8): + model, cfg = factory() + found_difference = False + + for seed in range(n_seeds): + torch.manual_seed(seed) + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + + # Prefill to get real cache + prefill_in = _prefill_inputs(input_ids, cfg) + with torch.no_grad(): + prefill_out = qeff.model(**prefill_in) + prefill_token = _extract_next_token(prefill_out.logits) + real_cache = prefill_out.past_key_values + decode_pos = PREFILL_LEN + + # Decode with REAL cache + with torch.no_grad(): + out_real = qeff.model(**_decode_inputs(prefill_token, decode_pos, real_cache)) + real_token = _extract_next_token(out_real.logits) + + # Decode with ZERO cache (what the old tests did) + with torch.no_grad(): + out_zero = qeff.model(**_decode_inputs(prefill_token, decode_pos, _zero_kv_cache(cfg))) + zero_token = _extract_next_token(out_zero.logits) + + if real_token != zero_token: + found_difference = True + break + + assert found_difference, ( + f"[{label}] Real-cache decode always produced the same token as zero-cache " + f"decode across {n_seeds} seeds. The KV cache may not be influencing output." + ) + + def test_llama_real_cache_differs_from_zero_cache(self): + self._assert_cache_influences_output(make_tiny_llama, "Llama") + + def test_mistral_real_cache_differs_from_zero_cache(self): + self._assert_cache_influences_output(make_tiny_mistral, "Mistral") + + def test_qwen2_real_cache_differs_from_zero_cache(self): + self._assert_cache_influences_output(make_tiny_qwen2, "Qwen2") + + +# --------------------------------------------------------------------------- +# Tests: Decode position advances strictly across steps +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +@pytest.mark.accuracy +class TestDecodePositionAdvancesStrictly: + """ + Each decode step must use a strictly increasing position_id. + If positions don't advance, the model writes to the same cache slot + every step, silently corrupting the KV cache. + """ + + def _assert_positions_advance(self, factory, label): + model, cfg = factory() + qeff = QEFFAutoModelForCausalLM(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + prefill_in = _prefill_inputs(input_ids, cfg) + + with torch.no_grad(): + prefill_out = qeff.model(**prefill_in) + + token = _extract_next_token(prefill_out.logits) + current_past = prefill_out.past_key_values + positions_used = [PREFILL_LEN - 1] # last prefill position + + for step in range(4): + next_pos = positions_used[-1] + 1 + decode_in = _decode_inputs(token, next_pos, current_past) + assert decode_in["position_ids"].item() == next_pos, ( + f"[{label}] Step {step}: position_ids={decode_in['position_ids'].item()}, expected {next_pos}" + ) + positions_used.append(next_pos) + + with torch.no_grad(): + out = qeff.model(**decode_in) + token = _extract_next_token(out.logits) + current_past = out.past_key_values + + for i in range(1, len(positions_used)): + assert positions_used[i] > positions_used[i - 1], ( + f"[{label}] Positions not strictly increasing: {positions_used}" + ) + + def test_gpt2_positions_advance(self): + self._assert_positions_advance(make_tiny_gpt2, "GPT2") + + def test_llama_positions_advance(self): + self._assert_positions_advance(make_tiny_llama, "Llama") + + def test_mistral_positions_advance(self): + self._assert_positions_advance(make_tiny_mistral, "Mistral") + + def test_qwen2_positions_advance(self): + self._assert_positions_advance(make_tiny_qwen2, "Qwen2") + + def test_phi3_positions_advance(self): + self._assert_positions_advance(make_tiny_phi3, "Phi3") + + +# --------------------------------------------------------------------------- +# Tests: Full pipeline — HF prefill token == QEff prefill token, then real decode +# --------------------------------------------------------------------------- + + +@pytest.mark.causal_lm +@pytest.mark.accuracy +class TestFullPipelineConsistency: + """ + Combined regression test: + 1. QEff prefill token must match HF greedy token. + 2. First decode step using REAL cache must produce a finite, valid token. + """ + + def _assert_full_pipeline(self, factory, label): + model, cfg = factory() + input_ids = torch.randint(0, VOCAB_SIZE, (1, PREFILL_LEN)) + + # HF baseline + with torch.no_grad(): + hf_logits = model(input_ids=input_ids).logits[:, -1, :] + hf_token = hf_logits.argmax(-1).item() + + # QEff prefill + qeff = QEFFAutoModelForCausalLM(model) + with torch.no_grad(): + prefill_out = qeff.model(**_prefill_inputs(input_ids, cfg)) + qeff_token = _extract_next_token(prefill_out.logits) + + assert hf_token == qeff_token, f"[{label}] Prefill token mismatch: HF={hf_token}, QEff={qeff_token}" + + # Decode with REAL cache + with torch.no_grad(): + decode_out = qeff.model(**_decode_inputs(qeff_token, PREFILL_LEN, prefill_out.past_key_values)) + + assert torch.isfinite(decode_out.logits).all(), ( + f"[{label}] Decode logits contain NaN/Inf after real-cache handoff" + ) + dec_token = _extract_next_token(decode_out.logits) + assert 0 <= dec_token < VOCAB_SIZE, f"[{label}] Decode token {dec_token} out of range [0, {VOCAB_SIZE})" + + def test_gpt2_full_pipeline(self): + self._assert_full_pipeline(make_tiny_gpt2, "GPT2") + + def test_llama_full_pipeline(self): + self._assert_full_pipeline(make_tiny_llama, "Llama") + + def test_mistral_full_pipeline(self): + self._assert_full_pipeline(make_tiny_mistral, "Mistral") + + def test_qwen2_full_pipeline(self): + self._assert_full_pipeline(make_tiny_qwen2, "Qwen2") + + def test_phi3_full_pipeline(self): + self._assert_full_pipeline(make_tiny_phi3, "Phi3") + + def test_gemma_full_pipeline(self): + self._assert_full_pipeline(make_tiny_gemma, "Gemma") diff --git a/tests/unit_test/models/test_sliding_window_cache.py b/tests/unit_test/models/test_sliding_window_cache.py new file mode 100644 index 000000000..27a415c6a --- /dev/null +++ b/tests/unit_test/models/test_sliding_window_cache.py @@ -0,0 +1,542 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Correctness tests for QEffSlidingWindowCache and QEffDynamicCache.update3D. + +Tests verify: + - QEffSlidingWindowCache: creation, update (sliding + non-sliding), modular scatter, + output shape, multi-layer independence, to_legacy_cache round-trip, get_seq_length + - QEffDynamicLayer.update3D / QEffDynamicCache.update3D: 3D KV shape (GPTBigCode) + - QEffHybridCacheForGPTOSS: full_cache_update_chunked, sliding_window_update_chunked + +All tests run on CPU only. +""" + +import pytest +import torch + +from QEfficient.transformers.cache_utils import ( + QEffDynamicCache, + QEffDynamicLayer, + QEffHybridCacheForGPTOSS, + QEffSlidingWindowCache, +) + +# --------------------------------------------------------------------------- +# Minimal config stub (no HF model needed) +# --------------------------------------------------------------------------- + + +class _FakeConfig: + """Minimal config stub for cache constructors.""" + + sliding_window_pattern = 2 # every 2nd layer is sliding + sliding_window = 4 + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_kv_4d(batch=1, heads=2, seq=8, head_dim=16): + k = torch.randn(batch, heads, seq, head_dim) + v = torch.randn(batch, heads, seq, head_dim) + return k, v + + +def make_kv_3d(batch=1, seq=8, kv_dim=32): + """3D KV tensors as used by GPTBigCode: [batch, seq, heads*head_dim].""" + k = torch.randn(batch, seq, kv_dim) + v = torch.randn(batch, seq, kv_dim) + return k, v + + +def pos_ids(batch=1, seq=8, start=0): + return torch.arange(start, start + seq).unsqueeze(0).expand(batch, -1) + + +# --------------------------------------------------------------------------- +# Tests: QEffSlidingWindowCache +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestQEffSlidingWindowCache: + """QEffSlidingWindowCache must correctly implement sliding-window KV caching.""" + + def test_creation_succeeds(self): + """Cache must be created without errors.""" + cfg = _FakeConfig() + cache = QEffSlidingWindowCache(cfg, batch_size=1, max_cache_len=16, sliding_window_len=4) + assert cache is not None + assert cache.max_cache_len == 16 + assert cache.sliding_window_len == 4 + assert cache.batch_size == 1 + + def test_initial_cache_is_empty(self): + """Newly created cache must have empty key/value lists.""" + cfg = _FakeConfig() + cache = QEffSlidingWindowCache(cfg, batch_size=1, max_cache_len=16, sliding_window_len=4) + assert len(cache.key_cache) == 0 + assert len(cache.value_cache) == 0 + + def test_len_returns_number_of_layers(self): + """__len__ must return the number of cached layers.""" + cfg = _FakeConfig() + cache = QEffSlidingWindowCache(cfg, batch_size=1, max_cache_len=16, sliding_window_len=4) + assert len(cache) == 0 + + k, v = make_kv_4d(seq=4) + cache.update(k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=4), "is_sliding": False}) + assert len(cache) == 1 + + cache.update( + k.clone(), v.clone(), layer_idx=1, cache_kwargs={"position_ids": pos_ids(seq=4), "is_sliding": True} + ) + assert len(cache) == 2 + + def test_first_update_non_sliding_stores_tensors(self): + """First update (non-sliding) must store tensors in the cache.""" + cfg = _FakeConfig() + cache = QEffSlidingWindowCache(cfg, batch_size=1, max_cache_len=16, sliding_window_len=4) + k, v = make_kv_4d(seq=8) + k_out, v_out = cache.update( + k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=8), "is_sliding": False} + ) + assert len(cache.key_cache) == 1 + assert k_out is not None + assert v_out is not None + + def test_first_update_returns_finite_tensors(self): + """First update must return finite tensors.""" + cfg = _FakeConfig() + cache = QEffSlidingWindowCache(cfg, batch_size=1, max_cache_len=16, sliding_window_len=4) + k, v = make_kv_4d(seq=8) + k_out, v_out = cache.update( + k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=8), "is_sliding": False} + ) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + def test_non_sliding_decode_scatter_at_correct_position(self): + """Non-sliding decode must scatter at the exact position_id.""" + cfg = _FakeConfig() + ctx_len = 16 + cache = QEffSlidingWindowCache(cfg, batch_size=1, max_cache_len=ctx_len, sliding_window_len=4) + + # Prefill with zeros + k_init = torch.zeros(1, 2, ctx_len, 8) + v_init = torch.zeros(1, 2, ctx_len, 8) + cache.update( + k_init, v_init, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=ctx_len), "is_sliding": False} + ) + + # Decode: write known value at position 5 + k_dec = torch.ones(1, 2, 1, 8) * 7.0 + v_dec = torch.ones(1, 2, 1, 8) * 7.0 + k_out, v_out = cache.update( + k_dec, v_dec, layer_idx=0, cache_kwargs={"position_ids": torch.tensor([[5]]), "is_sliding": False} + ) + assert k_out[0, 0, 5, 0].item() == pytest.approx(7.0, abs=1e-5) + + def test_sliding_modular_scatter_position(self): + """Sliding update must scatter at position % sliding_window_len.""" + cfg = _FakeConfig() + sliding_window_len = 4 + cache = QEffSlidingWindowCache(cfg, batch_size=1, max_cache_len=16, sliding_window_len=sliding_window_len) + + # Prefill sliding layer with zeros + k_init = torch.zeros(1, 2, sliding_window_len, 8) + v_init = torch.zeros(1, 2, sliding_window_len, 8) + cache.update( + k_init, + v_init, + layer_idx=0, + cache_kwargs={"position_ids": pos_ids(seq=sliding_window_len), "is_sliding": True}, + ) + + # Decode at position 5: slot = 5 % 4 = 1 + k_dec = torch.ones(1, 2, 1, 8) * 99.0 + v_dec = torch.ones(1, 2, 1, 8) * 99.0 + k_out, v_out = cache.update( + k_dec, v_dec, layer_idx=0, cache_kwargs={"position_ids": torch.tensor([[5]]), "is_sliding": True} + ) + # The output shape should be sliding_window_len + assert k_out.shape[2] == sliding_window_len + assert torch.isfinite(k_out).all() + + def test_output_shape_non_sliding_equals_ctx_len(self): + """Non-sliding update output must have shape matching ctx_len.""" + cfg = _FakeConfig() + ctx_len = 16 + cache = QEffSlidingWindowCache(cfg, batch_size=1, max_cache_len=ctx_len, sliding_window_len=4) + k, v = make_kv_4d(seq=ctx_len) + k_out, v_out = cache.update( + k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=ctx_len), "is_sliding": False} + ) + assert k_out.shape[2] == ctx_len + + def test_output_shape_sliding_equals_window_size(self): + """Sliding update output must have shape matching sliding_window_len.""" + cfg = _FakeConfig() + sliding_window_len = 4 + cache = QEffSlidingWindowCache(cfg, batch_size=1, max_cache_len=16, sliding_window_len=sliding_window_len) + k, v = make_kv_4d(seq=sliding_window_len) + k_out, v_out = cache.update( + k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=sliding_window_len), "is_sliding": True} + ) + assert k_out.shape[2] == sliding_window_len + + def test_multi_layer_independence(self): + """Different layers must not interfere with each other.""" + cfg = _FakeConfig() + cache = QEffSlidingWindowCache(cfg, batch_size=1, max_cache_len=16, sliding_window_len=4) + + for layer_idx in range(3): + k = torch.ones(1, 2, 8, 4) * float(layer_idx + 1) + v = torch.ones(1, 2, 8, 4) * float(layer_idx + 1) + cache.update(k, v, layer_idx=layer_idx, cache_kwargs={"position_ids": pos_ids(seq=8), "is_sliding": False}) + + # Each layer's cache must have its own value + for layer_idx in range(3): + expected = float(layer_idx + 1) + assert cache.key_cache[layer_idx][0, 0, 0, 0].item() == pytest.approx(expected, abs=1e-5) + + def test_to_legacy_cache_round_trip(self): + """to_legacy_cache must return a tuple of (key, value) pairs per layer.""" + cfg = _FakeConfig() + cache = QEffSlidingWindowCache(cfg, batch_size=1, max_cache_len=16, sliding_window_len=4) + + for layer_idx in range(2): + k, v = make_kv_4d(seq=8) + cache.update(k, v, layer_idx=layer_idx, cache_kwargs={"position_ids": pos_ids(seq=8), "is_sliding": False}) + + legacy = cache.to_legacy_cache() + assert isinstance(legacy, tuple) + assert len(legacy) == 2 + for layer_kv in legacy: + assert len(layer_kv) == 2 # (key, value) + + def test_get_seq_length_returns_correct_value(self): + """get_seq_length must return the sequence length of the cached layer.""" + cfg = _FakeConfig() + cache = QEffSlidingWindowCache(cfg, batch_size=1, max_cache_len=16, sliding_window_len=4) + + # Empty cache + assert cache.get_seq_length(layer_idx=0) == 0 + + # After update + k, v = make_kv_4d(seq=8) + cache.update(k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=8), "is_sliding": False}) + assert cache.get_seq_length(layer_idx=0) == 8 + + def test_update_returns_finite_tensors_after_decode(self): + """Decode update must return finite tensors.""" + cfg = _FakeConfig() + ctx_len = 16 + cache = QEffSlidingWindowCache(cfg, batch_size=1, max_cache_len=ctx_len, sliding_window_len=4) + + # Prefill + k, v = make_kv_4d(seq=ctx_len) + cache.update(k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=ctx_len), "is_sliding": False}) + + # Decode + k_dec = torch.randn(1, 2, 1, 16) + v_dec = torch.randn(1, 2, 1, 16) + k_out, v_out = cache.update( + k_dec, v_dec, layer_idx=0, cache_kwargs={"position_ids": torch.tensor([[ctx_len - 1]]), "is_sliding": False} + ) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + +# --------------------------------------------------------------------------- +# Tests: QEffDynamicLayer.update3D (GPTBigCode 3D KV cache) +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestQEffDynamicCache3D: + """QEffDynamicLayer.update3D must handle 3D KV tensors [batch, seq, kv_dim].""" + + def test_update3d_first_call_stores_tensors(self): + """First update3D call must store tensors in the layer.""" + layer = QEffDynamicLayer() + k, v = make_kv_3d(batch=1, seq=8, kv_dim=32) + k_out, v_out = layer.update3D(k, v, cache_kwargs={"position_ids": pos_ids(seq=8)}) + assert layer.keys is not None + assert layer.values is not None + assert k_out.shape == k.shape + assert v_out.shape == v.shape + + def test_update3d_output_is_finite(self): + """update3D must return finite tensors.""" + layer = QEffDynamicLayer() + k, v = make_kv_3d(batch=1, seq=8, kv_dim=32) + k_out, v_out = layer.update3D(k, v, cache_kwargs={"position_ids": pos_ids(seq=8)}) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + def test_update3d_output_shape_is_correct(self): + """update3D output must have shape [batch, ctx_len, kv_dim].""" + layer = QEffDynamicLayer() + batch, ctx_len, kv_dim = 1, 16, 32 + k = torch.zeros(batch, ctx_len, kv_dim) + v = torch.zeros(batch, ctx_len, kv_dim) + k_out, v_out = layer.update3D(k, v, cache_kwargs={"position_ids": pos_ids(seq=ctx_len)}) + assert k_out.shape == (batch, ctx_len, kv_dim) + assert v_out.shape == (batch, ctx_len, kv_dim) + + def test_update3d_scatter_at_correct_position(self): + """update3D decode must scatter at the correct position.""" + layer = QEffDynamicLayer() + batch, ctx_len, kv_dim = 1, 16, 32 + + # Prefill with zeros + k_init = torch.zeros(batch, ctx_len, kv_dim) + v_init = torch.zeros(batch, ctx_len, kv_dim) + layer.update3D(k_init, v_init, cache_kwargs={"position_ids": pos_ids(seq=ctx_len)}) + + # Decode: write known value at position 3 + k_dec = torch.ones(batch, 1, kv_dim) * 42.0 + v_dec = torch.ones(batch, 1, kv_dim) * 42.0 + k_out, v_out = layer.update3D(k_dec, v_dec, cache_kwargs={"position_ids": torch.tensor([[3]])}) + + assert k_out[0, 3, 0].item() == pytest.approx(42.0, abs=1e-5) + + def test_update3d_prior_positions_not_corrupted(self): + """update3D decode must not corrupt positions before the decode position.""" + layer = QEffDynamicLayer() + batch, ctx_len, kv_dim = 1, 16, 4 + + # Prefill with sequential values + k_init = ( + torch.arange(ctx_len, dtype=torch.float32).reshape(1, ctx_len, 1).expand(batch, ctx_len, kv_dim).clone() + ) + v_init = k_init.clone() + layer.update3D(k_init, v_init, cache_kwargs={"position_ids": pos_ids(seq=ctx_len)}) + + # Decode at position 5 + k_dec = torch.ones(batch, 1, kv_dim) * 999.0 + v_dec = torch.ones(batch, 1, kv_dim) * 999.0 + k_out, v_out = layer.update3D(k_dec, v_dec, cache_kwargs={"position_ids": torch.tensor([[5]])}) + + # Position 5 must be 999.0 + assert k_out[0, 5, 0].item() == pytest.approx(999.0, abs=1e-5) + # Positions before 5 must be preserved + assert k_out[0, 0, 0].item() == pytest.approx(0.0, abs=1e-5) + assert k_out[0, 3, 0].item() == pytest.approx(3.0, abs=1e-5) + assert k_out[0, 4, 0].item() == pytest.approx(4.0, abs=1e-5) + + def test_qeff_dynamic_cache_update3d_delegates_to_layer(self): + """QEffDynamicCache.update3D must delegate to the layer's update3D.""" + cache = QEffDynamicCache() + batch, ctx_len, kv_dim = 1, 8, 32 + k = torch.randn(batch, ctx_len, kv_dim) + v = torch.randn(batch, ctx_len, kv_dim) + k_out, v_out = cache.update3D(k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=ctx_len)}) + assert k_out is not None + assert v_out is not None + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + def test_qeff_dynamic_cache_update3d_creates_layer(self): + """QEffDynamicCache.update3D must create a new layer at the given index.""" + cache = QEffDynamicCache() + k, v = make_kv_3d(batch=1, seq=8, kv_dim=32) + cache.update3D(k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=8)}) + assert len(cache.layers) == 1 + + +# --------------------------------------------------------------------------- +# Tests: QEffHybridCacheForGPTOSS chunked methods +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestQEffHybridCacheForGPTOSSChunked: + """QEffHybridCacheForGPTOSS chunked prefill methods must be numerically correct.""" + + def _make_cache_with_layer(self, batch=1, heads=2, ctx_len=16, head_dim=8, sliding_window_len=4): + """Create a cache with one pre-initialized layer.""" + cfg = _FakeConfig() + cache = QEffHybridCacheForGPTOSS( + cfg, batch_size=batch, max_cache_len=ctx_len, sliding_window_len=sliding_window_len + ) + # Initialize layer 0 (full cache) + k = torch.zeros(batch, heads, ctx_len, head_dim) + v = torch.zeros(batch, heads, ctx_len, head_dim) + cache.key_cache.append(k) + cache.value_cache.append(v) + return cache + + def _make_sliding_cache_with_layer(self, batch=1, heads=2, sliding_window_len=4, head_dim=8): + """Create a cache with one pre-initialized sliding window layer.""" + cfg = _FakeConfig() + cache = QEffHybridCacheForGPTOSS(cfg, batch_size=batch, max_cache_len=16, sliding_window_len=sliding_window_len) + # Initialize layer 0 (sliding window) + k = torch.zeros(batch, heads, sliding_window_len, head_dim) + v = torch.zeros(batch, heads, sliding_window_len, head_dim) + cache.key_cache.append(k) + cache.value_cache.append(v) + return cache + + def test_full_cache_update_chunked_returns_finite(self): + """full_cache_update_chunked must return finite tensors.""" + cache = self._make_cache_with_layer() + batch, heads, seq_len, head_dim = 1, 2, 4, 8 + k = torch.randn(batch, heads, seq_len, head_dim) + v = torch.randn(batch, heads, seq_len, head_dim) + k_out, v_out = cache.full_cache_update_chunked( + k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=seq_len), "batch_index": None} + ) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + def test_full_cache_update_chunked_scatter_at_correct_position(self): + """full_cache_update_chunked must scatter at the correct position.""" + cache = self._make_cache_with_layer(ctx_len=16) + batch, heads, head_dim = 1, 2, 8 + + # Write known value at positions 0-3 + k = torch.ones(batch, heads, 4, head_dim) * 5.0 + v = torch.ones(batch, heads, 4, head_dim) * 5.0 + k_out, v_out = cache.full_cache_update_chunked( + k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=4), "batch_index": None} + ) + # Positions 0-3 should have value 5.0 + assert k_out[0, 0, 0, 0].item() == pytest.approx(5.0, abs=1e-5) + assert k_out[0, 0, 3, 0].item() == pytest.approx(5.0, abs=1e-5) + + def test_full_cache_update_chunked_output_shape(self): + """full_cache_update_chunked output must have the correct shape.""" + ctx_len = 16 + cache = self._make_cache_with_layer(ctx_len=ctx_len) + batch, heads, seq_len, head_dim = 1, 2, 4, 8 + k = torch.randn(batch, heads, seq_len, head_dim) + v = torch.randn(batch, heads, seq_len, head_dim) + k_out, v_out = cache.full_cache_update_chunked( + k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=seq_len), "batch_index": None} + ) + assert k_out.shape[2] == ctx_len + + def test_sliding_window_update_chunked_returns_finite(self): + """sliding_window_update_chunked must return finite tensors.""" + sliding_window_len = 4 + cache = self._make_sliding_cache_with_layer(sliding_window_len=sliding_window_len) + batch, heads, seq_len, head_dim = 1, 2, 4, 8 + k = torch.randn(batch, heads, seq_len, head_dim) + v = torch.randn(batch, heads, seq_len, head_dim) + k_out, v_out = cache.sliding_window_update_chunked( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": pos_ids(seq=seq_len), + "batch_index": None, + "sliding_window": sliding_window_len, + }, + ) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + def test_sliding_window_update_chunked_output_shape(self): + """sliding_window_update_chunked output must have the correct shape.""" + sliding_window_len = 4 + seq_len = 4 + cache = self._make_sliding_cache_with_layer(sliding_window_len=sliding_window_len) + batch, heads, head_dim = 1, 2, 8 + k = torch.randn(batch, heads, seq_len, head_dim) + v = torch.randn(batch, heads, seq_len, head_dim) + k_out, v_out = cache.sliding_window_update_chunked( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": pos_ids(seq=seq_len), + "batch_index": None, + "sliding_window": sliding_window_len, + }, + ) + # Output shape: seq_len + sliding_window_len + expected_ctx = seq_len + sliding_window_len + assert k_out.shape[2] == expected_ctx + + def test_sliding_window_update_chunked_with_larger_window(self): + """sliding_window_update_chunked with a larger window must return finite tensors.""" + sliding_window_len = 8 + seq_len = 4 + cache = self._make_sliding_cache_with_layer(sliding_window_len=sliding_window_len) + batch, heads, head_dim = 1, 2, 8 + k = torch.randn(batch, heads, seq_len, head_dim) + v = torch.randn(batch, heads, seq_len, head_dim) + k_out, v_out = cache.sliding_window_update_chunked( + k, + v, + layer_idx=0, + cache_kwargs={ + "position_ids": pos_ids(seq=seq_len), + "batch_index": None, + "sliding_window": sliding_window_len, + }, + ) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + +# --------------------------------------------------------------------------- +# Tests: CCL (Compute Context Length) cache path +# --------------------------------------------------------------------------- + + +@pytest.mark.cache +class TestCCLCachePath: + """QEffDynamicCache.update with CCL kwarg must work correctly.""" + + def test_update_with_ccl_returns_finite(self): + """update() with CCL kwarg must return finite tensors.""" + from QEfficient.transformers.cache_utils import QEffDynamicCache + + cache = QEffDynamicCache() + batch, heads, ctx_len, head_dim = 1, 2, 16, 8 + k = torch.randn(batch, heads, ctx_len, head_dim) + v = torch.randn(batch, heads, ctx_len, head_dim) + + # Prefill + cache.update(k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=ctx_len)}) + + # Decode with CCL + k_dec = torch.randn(batch, heads, 1, head_dim) + v_dec = torch.randn(batch, heads, 1, head_dim) + k_out, v_out = cache.update( + k_dec, v_dec, layer_idx=0, cache_kwargs={"position_ids": torch.tensor([[8]]), "CCL": 8} + ) + assert torch.isfinite(k_out).all() + assert torch.isfinite(v_out).all() + + def test_update_with_ccl_output_shape_matches_ccl(self): + """update() with CCL kwarg must return tensors with ctx_len=CCL.""" + from QEfficient.transformers.cache_utils import QEffDynamicCache + + cache = QEffDynamicCache() + batch, heads, ctx_len, head_dim = 1, 2, 16, 8 + k = torch.randn(batch, heads, ctx_len, head_dim) + v = torch.randn(batch, heads, ctx_len, head_dim) + + # Prefill + cache.update(k, v, layer_idx=0, cache_kwargs={"position_ids": pos_ids(seq=ctx_len)}) + + # Decode with CCL=8 (smaller than ctx_len=16) + ccl = 8 + k_dec = torch.randn(batch, heads, 1, head_dim) + v_dec = torch.randn(batch, heads, 1, head_dim) + k_out, v_out = cache.update( + k_dec, v_dec, layer_idx=0, cache_kwargs={"position_ids": torch.tensor([[4]]), "CCL": ccl} + ) + assert k_out.shape[2] == ccl + assert v_out.shape[2] == ccl diff --git a/tests/unit_test/transforms/__init__.py b/tests/unit_test/transforms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit_test/transforms/test_onnx_transforms.py b/tests/unit_test/transforms/test_onnx_transforms.py new file mode 100644 index 000000000..4c16d0a29 --- /dev/null +++ b/tests/unit_test/transforms/test_onnx_transforms.py @@ -0,0 +1,591 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Tests for ONNX transforms in QEfficient. + +Tests verify: + - FP16ClipTransform: importable, has apply method + - SplitTensorsTransform: importable, has apply method + - CustomOpTransform: importable, has apply method (registers custom ops for export) + - QEFFAutoModelForCausalLM._onnx_transforms contains FP16ClipTransform + SplitTensorsTransform + - ONNX graph structure after export: CtxScatter/CtxGather custom ops present + +All tests run on CPU only, using tiny in-memory models. +""" + +import pytest +from transformers import GPT2Config, GPT2LMHeadModel, LlamaConfig, LlamaForCausalLM + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + +VOCAB_SIZE = 500 +SEQ_LEN = 8 +CTX_LEN = 32 + + +def make_tiny_gpt2(): + cfg = GPT2Config(n_layer=2, n_head=2, n_embd=64, vocab_size=VOCAB_SIZE, n_positions=CTX_LEN, n_ctx=CTX_LEN) + return GPT2LMHeadModel(cfg).eval(), cfg + + +def make_tiny_llama(): + cfg = LlamaConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return LlamaForCausalLM(cfg).eval(), cfg + + +class TestONNXTransformsModuleStructure: + """ONNX transforms must be importable and have correct structure.""" + + def test_fp16_clip_transform_importable(self): + from QEfficient.base.onnx_transforms import FP16ClipTransform + + assert FP16ClipTransform is not None + + def test_split_tensors_transform_importable(self): + from QEfficient.base.onnx_transforms import SplitTensorsTransform + + assert SplitTensorsTransform is not None + + def test_custom_op_transform_importable(self): + from QEfficient.base.onnx_transforms import CustomOpTransform + + assert CustomOpTransform is not None + + def test_fp16_clip_has_apply_method(self): + from QEfficient.base.onnx_transforms import FP16ClipTransform + + assert hasattr(FP16ClipTransform, "apply") + assert callable(FP16ClipTransform.apply) + + def test_split_tensors_has_apply_method(self): + from QEfficient.base.onnx_transforms import SplitTensorsTransform + + assert hasattr(SplitTensorsTransform, "apply") + assert callable(SplitTensorsTransform.apply) + + def test_custom_op_transform_has_apply_method(self): + from QEfficient.base.onnx_transforms import CustomOpTransform + + assert hasattr(CustomOpTransform, "apply") + assert callable(CustomOpTransform.apply) + + def test_base_onnx_transform_importable(self): + from QEfficient.base.onnx_transforms import BaseOnnxTransform + + assert BaseOnnxTransform is not None + + def test_qeff_auto_model_has_onnx_transforms_list(self): + assert hasattr(QEFFAutoModelForCausalLM, "_onnx_transforms") + assert isinstance(QEFFAutoModelForCausalLM._onnx_transforms, list) + assert len(QEFFAutoModelForCausalLM._onnx_transforms) > 0 + + def test_onnx_transforms_list_contains_fp16_clip(self): + from QEfficient.base.onnx_transforms import FP16ClipTransform + + assert FP16ClipTransform in QEFFAutoModelForCausalLM._onnx_transforms, ( + f"FP16ClipTransform not in _onnx_transforms: {QEFFAutoModelForCausalLM._onnx_transforms}" + ) + + def test_onnx_transforms_list_contains_split_tensors(self): + from QEfficient.base.onnx_transforms import SplitTensorsTransform + + assert SplitTensorsTransform in QEFFAutoModelForCausalLM._onnx_transforms, ( + f"SplitTensorsTransform not in _onnx_transforms: {QEFFAutoModelForCausalLM._onnx_transforms}" + ) + + def test_all_onnx_transforms_are_subclasses_of_base(self): + from QEfficient.base.onnx_transforms import BaseOnnxTransform + + for transform in QEFFAutoModelForCausalLM._onnx_transforms: + assert issubclass(transform, BaseOnnxTransform), f"{transform} is not a subclass of BaseOnnxTransform" + + def test_rename_function_outputs_transform_importable(self): + from QEfficient.base.onnx_transforms import RenameFunctionOutputsTransform + + assert RenameFunctionOutputsTransform is not None + assert hasattr(RenameFunctionOutputsTransform, "apply") + + +@pytest.mark.onnx +@pytest.mark.slow +class TestONNXTransformApplication: + """ONNX transforms must be applied during export and produce valid graphs.""" + + def test_gpt2_onnx_export_applies_ctx_scatter_gather(self, tmp_export_dir): + """After export, ONNX graph must contain CtxScatter/CtxGather custom ops.""" + import onnx + + model, cfg = make_tiny_gpt2() + qeff_model = QEFFAutoModelForCausalLM(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + node_op_types = {node.op_type for node in onnx_model.graph.node} + has_custom_ops = "CtxScatter" in node_op_types or "CtxGather" in node_op_types + assert has_custom_ops, ( + f"Expected CtxScatter/CtxGather custom ops in ONNX graph. Found op types: {node_op_types}" + ) + + def test_llama_onnx_export_applies_ctx_scatter_gather(self, tmp_export_dir): + """Llama ONNX graph must contain CtxScatter/CtxGather custom ops.""" + import onnx + + model, cfg = make_tiny_llama() + qeff_model = QEFFAutoModelForCausalLM(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + node_op_types = {node.op_type for node in onnx_model.graph.node} + has_custom_ops = "CtxScatter" in node_op_types or "CtxGather" in node_op_types + assert has_custom_ops, ( + f"Expected CtxScatter/CtxGather custom ops in Llama ONNX graph. Found op types: {node_op_types}" + ) + + def test_gpt2_onnx_position_ids_are_int64(self, tmp_export_dir): + """The ONNX graph must accept int64 position_ids input.""" + import onnx + + model, cfg = make_tiny_gpt2() + qeff_model = QEFFAutoModelForCausalLM(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + for inp in onnx_model.graph.input: + if inp.name == "position_ids": + # Type 7 = INT64 in ONNX + assert inp.type.tensor_type.elem_type == 7, ( + f"position_ids must be INT64 (type 7), got type {inp.type.tensor_type.elem_type}" + ) + break + + def test_gpt2_onnx_graph_has_no_dangling_nodes(self, tmp_export_dir): + """All ONNX graph nodes must have valid inputs/outputs.""" + import onnx + + model, cfg = make_tiny_gpt2() + qeff_model = QEFFAutoModelForCausalLM(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + defined = {inp.name for inp in onnx_model.graph.input} + defined.update({init.name for init in onnx_model.graph.initializer}) + for node in onnx_model.graph.node: + defined.update(node.output) + for node in onnx_model.graph.node: + for inp in node.input: + if inp: + assert inp in defined, f"Node '{node.op_type}' has undefined input '{inp}'" + + def test_gpt2_onnx_retained_state_count_matches_layers(self, tmp_export_dir): + """Number of RetainedState outputs must equal 2 * n_layers.""" + import onnx + + n_layers = 2 + model, cfg = make_tiny_gpt2() + qeff_model = QEFFAutoModelForCausalLM(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + retained = [out.name for out in onnx_model.graph.output if "RetainedState" in out.name] + assert len(retained) == 2 * n_layers, ( + f"Expected {2 * n_layers} RetainedState outputs, got {len(retained)}: {retained}" + ) + + def test_llama_onnx_retained_state_count_matches_layers(self, tmp_export_dir): + """Llama RetainedState outputs must equal 2 * n_layers.""" + import onnx + + n_layers = 2 + model, cfg = make_tiny_llama() + qeff_model = QEFFAutoModelForCausalLM(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + retained = [out.name for out in onnx_model.graph.output if "RetainedState" in out.name] + assert len(retained) == 2 * n_layers, f"Expected {2 * n_layers} RetainedState outputs, got {len(retained)}" + + def test_gpt2_onnx_input_ids_are_int64(self, tmp_export_dir): + """input_ids must be INT64 in the ONNX graph.""" + import onnx + + model, cfg = make_tiny_gpt2() + qeff_model = QEFFAutoModelForCausalLM(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + for inp in onnx_model.graph.input: + if inp.name == "input_ids": + assert inp.type.tensor_type.elem_type == 7, ( + f"input_ids must be INT64 (type 7), got type {inp.type.tensor_type.elem_type}" + ) + break + + def test_gpt2_onnx_kv_cache_inputs_are_float32(self, tmp_export_dir): + """KV cache inputs must be FLOAT32 in the ONNX graph.""" + import onnx + + model, cfg = make_tiny_gpt2() + qeff_model = QEFFAutoModelForCausalLM(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + for inp in onnx_model.graph.input: + if "past_key" in inp.name or "past_value" in inp.name: + # Type 1 = FLOAT in ONNX + assert inp.type.tensor_type.elem_type == 1, ( + f"{inp.name} must be FLOAT32 (type 1), got type {inp.type.tensor_type.elem_type}" + ) + + +# --------------------------------------------------------------------------- +# Tests: FP16ClipTransform functional correctness +# --------------------------------------------------------------------------- + + +@pytest.mark.onnx +@pytest.mark.slow +class TestFP16ClipTransformFunctional: + """FP16ClipTransform must clip FP32 initializer values to the FP16 range.""" + + def _make_onnx_model_with_large_initializer(self): + """Create a minimal ONNX model with an initializer value > FP16 max (65504).""" + import numpy as np + import onnx + import onnx.helper as helper + import onnx.numpy_helper as numpy_helper + + # Create a simple Add node: output = input + large_weight + large_value = np.array([100000.0, -100000.0, 1.0, 0.5], dtype=np.float32) + weight_init = numpy_helper.from_array(large_value, name="large_weight") + + input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [4]) + output_tensor = helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [4]) + add_node = helper.make_node("Add", inputs=["input", "large_weight"], outputs=["output"]) + + graph = helper.make_graph([add_node], "test_graph", [input_tensor], [output_tensor], [weight_init]) + model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + return model + + def test_fp16_clip_transform_clips_out_of_range_values(self, tmp_export_dir): + """FP16ClipTransform.apply operates on individual tensors. + It must clip FP32 values > 65504 to fp16_max.""" + import numpy as np + import onnx.numpy_helper as numpy_helper + + from QEfficient.base.onnx_transforms import FP16ClipTransform + + onnx_model = self._make_onnx_model_with_large_initializer() + fp16_max = np.finfo(np.float16).max # 65504 + fp16_min = -fp16_max + + # Apply FP16ClipTransform to each initializer tensor + any_clipped = False + for init in onnx_model.graph.initializer: + clipped = FP16ClipTransform.apply(init, str(tmp_export_dir), fp16_max, fp16_min) + if clipped: + any_clipped = True + + assert any_clipped, "FP16ClipTransform must clip at least one out-of-range tensor" + + # Check that the large initializer values are clipped + for init in onnx_model.graph.initializer: + if init.name == "large_weight": + values = numpy_helper.to_array(init) + assert np.all(np.abs(values) <= fp16_max + 1), ( + f"Values must be clipped to FP16 range, got max abs: {np.max(np.abs(values))}" + ) + + def test_fp16_clip_transform_preserves_in_range_values(self, tmp_export_dir): + """FP16ClipTransform must not modify values within the FP16 range.""" + import numpy as np + import onnx + import onnx.helper as helper + import onnx.numpy_helper as numpy_helper + + from QEfficient.base.onnx_transforms import FP16ClipTransform + + # Create model with in-range values + in_range_values = np.array([1.0, -1.0, 100.0, -100.0], dtype=np.float32) + weight_init = numpy_helper.from_array(in_range_values, name="in_range_weight") + input_tensor = helper.make_tensor_value_info("input", onnx.TensorProto.FLOAT, [4]) + output_tensor = helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [4]) + add_node = helper.make_node("Add", inputs=["input", "in_range_weight"], outputs=["output"]) + graph = helper.make_graph([add_node], "test_graph", [input_tensor], [output_tensor], [weight_init]) + onnx_model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)]) + + fp16_max = np.finfo(np.float16).max + fp16_min = -fp16_max + + # Apply to each initializer + for init in onnx_model.graph.initializer: + FP16ClipTransform.apply(init, str(tmp_export_dir), fp16_max, fp16_min) + + # In-range values must be preserved + for init in onnx_model.graph.initializer: + if init.name == "in_range_weight": + values = numpy_helper.to_array(init) + np.testing.assert_allclose(values, in_range_values, rtol=1e-5) + + def test_fp16_clip_transform_handles_negative_out_of_range(self, tmp_export_dir): + """FP16ClipTransform must clip negative values < -65504 to -65504.""" + import numpy as np + import onnx.numpy_helper as numpy_helper + + from QEfficient.base.onnx_transforms import FP16ClipTransform + + onnx_model = self._make_onnx_model_with_large_initializer() + fp16_max = np.finfo(np.float16).max # 65504 + fp16_min = -fp16_max + + for init in onnx_model.graph.initializer: + FP16ClipTransform.apply(init, str(tmp_export_dir), fp16_max, fp16_min) + + for init in onnx_model.graph.initializer: + if init.name == "large_weight": + values = numpy_helper.to_array(init) + assert np.all(values >= fp16_min - 1), f"Negative values must be clipped to >= {fp16_min}" + + +# --------------------------------------------------------------------------- +# Tests: RenameFunctionOutputsTransform +# --------------------------------------------------------------------------- + + +@pytest.mark.onnx +@pytest.mark.slow +class TestRenameFunctionOutputsTransform: + """RenameFunctionOutputsTransform must rename KV outputs to RetainedState names.""" + + def test_rename_transform_is_importable(self): + """RenameFunctionOutputsTransform must be importable.""" + from QEfficient.base.onnx_transforms import RenameFunctionOutputsTransform + + assert RenameFunctionOutputsTransform is not None + + def test_rename_transform_has_apply_method(self): + """RenameFunctionOutputsTransform must have an apply classmethod.""" + from QEfficient.base.onnx_transforms import RenameFunctionOutputsTransform + + assert hasattr(RenameFunctionOutputsTransform, "apply") + assert callable(RenameFunctionOutputsTransform.apply) + + def test_rename_transform_output_count_unchanged(self, tmp_export_dir): + """After RenameFunctionOutputsTransform, output count must be unchanged. + RenameFunctionOutputsTransform.apply(model) takes only the model.""" + import onnx + + from QEfficient.base.onnx_transforms import RenameFunctionOutputsTransform + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + model, cfg = make_tiny_gpt2() + qeff_model = QEFFAutoModelForCausalLM(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + + output_count_before = len(onnx_model.graph.output) + # RenameFunctionOutputsTransform.apply takes only the model (no path) + RenameFunctionOutputsTransform.apply(onnx_model) + output_count_after = len(onnx_model.graph.output) + + assert output_count_before == output_count_after, ( + f"Output count changed: {output_count_before} → {output_count_after}" + ) + + +# --------------------------------------------------------------------------- +# Tests: SplitTensorsTransform functional (GAP E) +# --------------------------------------------------------------------------- + + +class TestSplitTensorsTransformFunctional: + """SplitTensorsTransform must correctly map tensors to external data files.""" + + def test_split_tensors_transform_importable(self): + """SplitTensorsTransform must be importable.""" + from QEfficient.base.onnx_transforms import SplitTensorsTransform + + assert SplitTensorsTransform is not None + + def test_split_tensors_transform_has_apply_classmethod(self): + """SplitTensorsTransform.apply must be a classmethod.""" + import inspect + + from QEfficient.base.onnx_transforms import SplitTensorsTransform + + assert isinstance( + inspect.getattr_static(SplitTensorsTransform, "apply"), + classmethod, + ) + + def test_split_tensors_apply_populates_mapping(self): + """SplitTensorsTransform.apply must add tensor to mapping dict.""" + import numpy as np + import onnx.numpy_helper as numpy_helper + + from QEfficient.base.onnx_transforms import SplitTensorsTransform + + # Create a dummy tensor + arr = np.random.randn(10, 10).astype(np.float32) + tensor = numpy_helper.from_array(arr, name="test_tensor") + + mapping = {} + SplitTensorsTransform.apply(tensor, model_name="test_model", file_num=0, mapping=mapping) + + assert "test_tensor" in mapping, ( + f"SplitTensorsTransform must add tensor to mapping. Got: {list(mapping.keys())}" + ) + + def test_split_tensors_apply_assigns_correct_file_name(self): + """SplitTensorsTransform.apply must assign correct file name.""" + import numpy as np + import onnx.numpy_helper as numpy_helper + + from QEfficient.base.onnx_transforms import SplitTensorsTransform + + arr = np.ones((5, 5), dtype=np.float32) + tensor = numpy_helper.from_array(arr, name="weight_tensor") + + mapping = {} + SplitTensorsTransform.apply(tensor, model_name="mymodel", file_num=3, mapping=mapping) + + assert "weight_tensor" in mapping + _, file_name = mapping["weight_tensor"] + assert file_name == "mymodel_3.onnx.data", f"Expected 'mymodel_3.onnx.data', got '{file_name}'" + + def test_split_tensors_apply_stores_tensor_in_mapping(self): + """SplitTensorsTransform.apply must store the tensor proto in mapping.""" + import numpy as np + import onnx.numpy_helper as numpy_helper + + from QEfficient.base.onnx_transforms import SplitTensorsTransform + + arr = np.eye(4, dtype=np.float32) + tensor = numpy_helper.from_array(arr, name="eye_tensor") + + mapping = {} + SplitTensorsTransform.apply(tensor, model_name="model", file_num=1, mapping=mapping) + + stored_tensor, _ = mapping["eye_tensor"] + assert stored_tensor is tensor, "SplitTensorsTransform must store the original tensor proto" + + def test_split_tensors_apply_multiple_tensors(self): + """SplitTensorsTransform.apply must handle multiple tensors in same mapping.""" + import numpy as np + import onnx.numpy_helper as numpy_helper + + from QEfficient.base.onnx_transforms import SplitTensorsTransform + + mapping = {} + for i in range(5): + arr = np.random.randn(3, 3).astype(np.float32) + tensor = numpy_helper.from_array(arr, name=f"tensor_{i}") + SplitTensorsTransform.apply(tensor, model_name="model", file_num=i, mapping=mapping) + + assert len(mapping) == 5, f"Expected 5 entries in mapping, got {len(mapping)}" + for i in range(5): + assert f"tensor_{i}" in mapping + + +# --------------------------------------------------------------------------- +# Tests: CustomOpTransform structure (GAP E) +# --------------------------------------------------------------------------- + + +class TestCustomOpTransformStructure: + """CustomOpTransform must have correct structure and contain all expected custom ops.""" + + def test_custom_op_transform_importable(self): + """CustomOpTransform must be importable.""" + from QEfficient.base.onnx_transforms import CustomOpTransform + + assert CustomOpTransform is not None + + def test_custom_op_transform_has_custom_ops_dict(self): + """CustomOpTransform must have a _custom_ops dict.""" + from QEfficient.base.onnx_transforms import CustomOpTransform + + assert hasattr(CustomOpTransform, "_custom_ops") + assert isinstance(CustomOpTransform._custom_ops, dict) + assert len(CustomOpTransform._custom_ops) > 0 + + def test_custom_op_transform_contains_rms_norm(self): + """CustomOpTransform._custom_ops must contain 'CustomRMSNormFunc'.""" + from QEfficient.base.onnx_transforms import CustomOpTransform + + assert "CustomRMSNormFunc" in CustomOpTransform._custom_ops, ( + f"CustomRMSNormFunc not in _custom_ops: {list(CustomOpTransform._custom_ops.keys())}" + ) + + def test_custom_op_transform_contains_ctx_scatter(self): + """CustomOpTransform._custom_ops must contain 'CtxScatterFunc'.""" + from QEfficient.base.onnx_transforms import CustomOpTransform + + assert "CtxScatterFunc" in CustomOpTransform._custom_ops + + def test_custom_op_transform_contains_ctx_gather(self): + """CustomOpTransform._custom_ops must contain 'CtxGatherFunc'.""" + from QEfficient.base.onnx_transforms import CustomOpTransform + + assert "CtxGatherFunc" in CustomOpTransform._custom_ops + + def test_custom_op_transform_rms_norm_maps_to_custom_rms_norm(self): + """CustomRMSNormFunc must map to CustomRMSNorm class.""" + from QEfficient.base.onnx_transforms import CustomOpTransform + from QEfficient.customop.rms_norm import CustomRMSNorm + + _, onnxscript_func = CustomOpTransform._custom_ops["CustomRMSNormFunc"] + assert onnxscript_func is CustomRMSNorm, f"CustomRMSNormFunc must map to CustomRMSNorm, got {onnxscript_func}" + + def test_custom_op_transform_all_ops_have_to_function_proto(self): + """All custom ops in CustomOpTransform must have to_function_proto method.""" + from QEfficient.base.onnx_transforms import CustomOpTransform + + for op_name, (_, onnxscript_func) in CustomOpTransform._custom_ops.items(): + assert hasattr(onnxscript_func, "to_function_proto"), ( + f"Custom op '{op_name}' onnxscript_func must have to_function_proto method" + ) + + @pytest.mark.onnx + @pytest.mark.slow + def test_custom_op_transform_apply_adds_rms_norm_to_model_functions(self, tmp_export_dir): + """After CustomOpTransform.apply, model.functions must contain CustomRMSNorm.""" + import onnx + + from QEfficient.base.onnx_transforms import CustomOpTransform + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + model, cfg = make_tiny_llama() + qeff_model = QEFFAutoModelForCausalLM(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + + # Apply CustomOpTransform + CustomOpTransform.apply(onnx_model) + + # Check that CustomRMSNorm is in model.functions + function_names = {f.name for f in onnx_model.functions} + assert "CustomRMSNorm" in function_names, ( + f"CustomRMSNorm not in model.functions after CustomOpTransform.apply. Found: {function_names}" + ) + + @pytest.mark.onnx + @pytest.mark.slow + def test_llama_onnx_has_custom_rms_norm_after_export(self, tmp_export_dir): + """Llama ONNX export must include CustomRMSNorm in model functions.""" + import onnx + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + model, cfg = make_tiny_llama() + qeff_model = QEFFAutoModelForCausalLM(model) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + + function_names = {f.name for f in onnx_model.functions} + assert "CustomRMSNorm" in function_names, ( + f"Llama ONNX must have CustomRMSNorm function. Found: {function_names}" + ) diff --git a/tests/unit_test/transforms/test_peft_transforms.py b/tests/unit_test/transforms/test_peft_transforms.py new file mode 100644 index 000000000..80c1dcf46 --- /dev/null +++ b/tests/unit_test/transforms/test_peft_transforms.py @@ -0,0 +1,432 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Tests for PEFT/LoRA transforms in QEfficient. + +Tests verify: + - QEffPeftModelForCausalLM: importable, has correct class structure + - LoRA pytorch transforms: importable, have apply method + - LoRA ONNX transforms: importable, have apply method + - Wrapping a tiny Llama model with LoRA adapter works without error + - LoRA-wrapped model produces finite logits + +All tests run on CPU only, no network downloads required. +""" + +import pytest +import torch +from transformers import LlamaConfig, LlamaForCausalLM + +VOCAB_SIZE = 500 +SEQ_LEN = 8 +CTX_LEN = 32 + + +def make_tiny_llama(): + cfg = LlamaConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return LlamaForCausalLM(cfg).eval(), cfg + + +# --------------------------------------------------------------------------- +# Tests: PEFT module importability +# --------------------------------------------------------------------------- + + +class TestPEFTModuleImportability: + """PEFT modules must be importable and have correct structure.""" + + def test_qeff_peft_model_for_causal_lm_importable(self): + from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM + + assert QEffAutoPeftModelForCausalLM is not None + + def test_peft_pytorch_transforms_importable(self): + from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform + + assert PeftModelInputsTransform is not None + + def test_peft_onnx_transforms_importable(self): + from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform + + assert AdapterWeightsToInputsTransform is not None + + def test_qeff_peft_model_has_from_pretrained(self): + from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM + + assert hasattr(QEffAutoPeftModelForCausalLM, "from_pretrained") + assert callable(QEffAutoPeftModelForCausalLM.from_pretrained) + + def test_qeff_peft_model_has_pytorch_transforms(self): + from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM + + assert hasattr(QEffAutoPeftModelForCausalLM, "_pytorch_transforms") + assert isinstance(QEffAutoPeftModelForCausalLM._pytorch_transforms, list) + + def test_qeff_peft_model_has_onnx_transforms(self): + from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM + + assert hasattr(QEffAutoPeftModelForCausalLM, "_onnx_transforms") + assert isinstance(QEffAutoPeftModelForCausalLM._onnx_transforms, list) + + def test_peft_inputs_transform_has_apply(self): + from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform + + assert hasattr(PeftModelInputsTransform, "apply") + assert callable(PeftModelInputsTransform.apply) + + def test_adapter_weights_transform_has_apply(self): + from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform + + assert hasattr(AdapterWeightsToInputsTransform, "apply") + assert callable(AdapterWeightsToInputsTransform.apply) + + def test_peft_model_importable_from_qefficient(self): + """QEffAutoPeftModelForCausalLM must be accessible from the QEfficient package.""" + import QEfficient + + assert hasattr(QEfficient, "QEffAutoPeftModelForCausalLM") + + +# --------------------------------------------------------------------------- +# Tests: LoRA transform structure +# --------------------------------------------------------------------------- + + +class TestLoRATransformStructure: + """LoRA transforms must have correct structure.""" + + def test_peft_inputs_transform_has_apply_classmethod(self): + import inspect + + from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform + + assert isinstance( + inspect.getattr_static(PeftModelInputsTransform, "apply"), + classmethod, + ), "PeftModelInputsTransform.apply must be a classmethod" + + def test_adapter_weights_transform_has_apply_classmethod(self): + import inspect + + from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform + + assert isinstance( + inspect.getattr_static(AdapterWeightsToInputsTransform, "apply"), + classmethod, + ), "AdapterWeightsToInputsTransform.apply must be a classmethod" + + def test_peft_pytorch_transforms_include_peft_inputs_transform(self): + from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM + from QEfficient.peft.pytorch_transforms import PeftModelInputsTransform + + assert PeftModelInputsTransform in QEffAutoPeftModelForCausalLM._pytorch_transforms, ( + "PeftModelInputsTransform not in QEffAutoPeftModelForCausalLM._pytorch_transforms" + ) + + +# --------------------------------------------------------------------------- +# Tests: LoRA wrapping with peft library +# --------------------------------------------------------------------------- + + +class TestLoRAWrapping: + """LoRA adapter wrapping must work without error on a tiny model.""" + + def _make_lora_model(self): + """Create a tiny Llama model with a LoRA adapter using peft library.""" + try: + from peft import LoraConfig, get_peft_model + except ImportError: + pytest.skip("peft library not installed") + + model, cfg = make_tiny_llama() + lora_config = LoraConfig( + r=4, + lora_alpha=8, + target_modules=["q_proj", "v_proj"], + lora_dropout=0.0, + bias="none", + task_type="CAUSAL_LM", + ) + lora_model = get_peft_model(model, lora_config) + return lora_model, cfg + + def test_lora_model_wraps_without_error(self): + lora_model, cfg = self._make_lora_model() + assert lora_model is not None + + def test_lora_model_has_lora_parameters(self): + lora_model, cfg = self._make_lora_model() + lora_params = [n for n, _ in lora_model.named_parameters() if "lora_" in n] + assert len(lora_params) > 0, "LoRA model must have lora_ parameters" + + def test_lora_model_forward_produces_finite_logits(self): + lora_model, cfg = self._make_lora_model() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + with torch.no_grad(): + out = lora_model(input_ids=input_ids) + assert torch.isfinite(out.logits).all(), "LoRA model must produce finite logits" + + def test_qeff_peft_model_wraps_lora_model(self): + """QEffAutoPeftModelForCausalLM must wrap a LoRA model without error.""" + from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM + + lora_model, cfg = self._make_lora_model() + qeff_peft = QEffAutoPeftModelForCausalLM(lora_model) + assert qeff_peft is not None + assert hasattr(qeff_peft, "model") + + def test_qeff_peft_model_has_model_name(self): + from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM + + lora_model, cfg = self._make_lora_model() + qeff_peft = QEffAutoPeftModelForCausalLM(lora_model) + assert hasattr(qeff_peft, "model_name") + assert isinstance(qeff_peft.model_name, str) + assert len(qeff_peft.model_name) > 0 + + def test_qeff_peft_model_forward_produces_finite_logits(self): + """QEffAutoPeftModelForCausalLM forward must produce finite logits.""" + from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM + + lora_model, cfg = self._make_lora_model() + qeff_peft = QEffAutoPeftModelForCausalLM(lora_model) + + n_layers = cfg.num_hidden_layers + n_kv = cfg.num_key_value_heads + head_dim = cfg.hidden_size // cfg.num_attention_heads + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + position_ids = torch.arange(SEQ_LEN).unsqueeze(0) + past_key_values = tuple( + ( + torch.zeros(1, n_kv, CTX_LEN, head_dim), + torch.zeros(1, n_kv, CTX_LEN, head_dim), + ) + for _ in range(n_layers) + ) + with torch.no_grad(): + out = qeff_peft.model( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_key_values, + ) + assert torch.isfinite(out.logits).all(), "QEffPeftModelForCausalLM must produce finite logits" + + +# --------------------------------------------------------------------------- +# Tests: LoRA accuracy vs base model (GAP G) +# --------------------------------------------------------------------------- + + +class TestLoRAAccuracyVsBase: + """LoRA model must produce different logits than base model (LoRA changes outputs).""" + + def _make_lora_model_and_base(self): + """Create a tiny Llama model and a LoRA-wrapped version.""" + try: + from peft import LoraConfig, get_peft_model + except ImportError: + pytest.skip("peft library not installed") + + model, cfg = make_tiny_llama() + # Save base model logits before LoRA wrapping + base_model = model + + lora_config = LoraConfig( + r=4, + lora_alpha=8, + target_modules=["q_proj", "v_proj"], + lora_dropout=0.0, + bias="none", + task_type="CAUSAL_LM", + ) + lora_model = get_peft_model(base_model, lora_config) + return lora_model, base_model, cfg + + def test_lora_model_logits_are_finite(self): + """LoRA model logits must be finite (no NaN/Inf).""" + lora_model, base_model, cfg = self._make_lora_model_and_base() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + with torch.no_grad(): + out = lora_model(input_ids=input_ids) + assert torch.isfinite(out.logits).all(), "LoRA model must produce finite logits" + + def test_lora_model_output_shape_matches_base(self): + """LoRA model output shape must match base model output shape.""" + lora_model, base_model, cfg = self._make_lora_model_and_base() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + with torch.no_grad(): + lora_out = lora_model(input_ids=input_ids) + assert lora_out.logits.shape == (1, SEQ_LEN, VOCAB_SIZE), f"LoRA output shape mismatch: {lora_out.logits.shape}" + + def test_lora_model_with_random_weights_differs_from_base(self): + """LoRA model with random (non-zero) weights must produce different logits than base.""" + try: + from peft import LoraConfig, get_peft_model + except ImportError: + pytest.skip("peft library not installed") + + model, cfg = make_tiny_llama() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + # Get base model logits + with torch.no_grad(): + base_logits = model(input_ids=input_ids).logits + + # Wrap with LoRA and initialize with non-zero weights + lora_config = LoraConfig( + r=4, + lora_alpha=8, + target_modules=["q_proj", "v_proj"], + lora_dropout=0.0, + bias="none", + task_type="CAUSAL_LM", + ) + lora_model = get_peft_model(model, lora_config) + + # Initialize LoRA B matrices with non-zero values (default is zeros) + for name, param in lora_model.named_parameters(): + if "lora_B" in name: + torch.nn.init.normal_(param, mean=0.0, std=0.1) + + with torch.no_grad(): + lora_logits = lora_model(input_ids=input_ids).logits + + max_diff = (base_logits - lora_logits).abs().max().item() + assert max_diff > 1e-6, ( + f"LoRA model with non-zero B weights must produce different logits than base. max_diff={max_diff:.2e}" + ) + + def test_lora_model_with_zero_b_weights_matches_base(self): + """LoRA model with zero B weights (default init) must produce same logits as base.""" + try: + from peft import LoraConfig, get_peft_model + except ImportError: + pytest.skip("peft library not installed") + + model, cfg = make_tiny_llama() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + # Get base model logits + with torch.no_grad(): + base_logits = model(input_ids=input_ids).logits + + # Wrap with LoRA (default: B=0, so output is same as base) + lora_config = LoraConfig( + r=4, + lora_alpha=8, + target_modules=["q_proj", "v_proj"], + lora_dropout=0.0, + bias="none", + task_type="CAUSAL_LM", + ) + lora_model = get_peft_model(model, lora_config) + + with torch.no_grad(): + lora_logits = lora_model(input_ids=input_ids).logits + + max_diff = (base_logits - lora_logits).abs().max().item() + assert max_diff < 1e-5, f"LoRA model with zero B weights must match base model. max_diff={max_diff:.2e}" + + def test_lora_trainable_params_are_subset_of_all_params(self): + """LoRA trainable parameters must be a subset of all parameters.""" + try: + from peft import LoraConfig, get_peft_model + except ImportError: + pytest.skip("peft library not installed") + + model, cfg = make_tiny_llama() + lora_config = LoraConfig( + r=4, + lora_alpha=8, + target_modules=["q_proj", "v_proj"], + lora_dropout=0.0, + bias="none", + task_type="CAUSAL_LM", + ) + lora_model = get_peft_model(model, lora_config) + + trainable_params = sum(p.numel() for p in lora_model.parameters() if p.requires_grad) + total_params = sum(p.numel() for p in lora_model.parameters()) + assert trainable_params < total_params, ( + f"LoRA trainable params ({trainable_params}) must be less than total ({total_params})" + ) + + +# --------------------------------------------------------------------------- +# Tests: AdapterWeightsToInputsTransform ONNX graph (GAP G) +# --------------------------------------------------------------------------- + + +class TestAdapterWeightsToInputsTransformStructure: + """AdapterWeightsToInputsTransform must have correct structure.""" + + def test_adapter_weights_transform_importable(self): + from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform + + assert AdapterWeightsToInputsTransform is not None + + def test_adapter_weights_transform_has_apply_method(self): + from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform + + assert hasattr(AdapterWeightsToInputsTransform, "apply") + assert callable(AdapterWeightsToInputsTransform.apply) + + def test_adapter_weights_transform_apply_is_classmethod(self): + import inspect + + from QEfficient.peft.onnx_transforms import AdapterWeightsToInputsTransform + + assert isinstance( + inspect.getattr_static(AdapterWeightsToInputsTransform, "apply"), + classmethod, + ), "AdapterWeightsToInputsTransform.apply must be a classmethod" + + def test_adapter_weights_transform_in_peft_onnx_transforms(self): + """AdapterWeightsToInputsTransform (from base or peft) must be in QEffAutoPeftModelForCausalLM._onnx_transforms.""" + from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM + + # AdapterWeightsToInputsTransform may be in base.onnx_transforms or peft.onnx_transforms + transform_names = [t.__name__ for t in QEffAutoPeftModelForCausalLM._onnx_transforms] + assert "AdapterWeightsToInputsTransform" in transform_names, ( + f"AdapterWeightsToInputsTransform not in QEffAutoPeftModelForCausalLM._onnx_transforms. " + f"Found: {transform_names}" + ) + + def test_peft_onnx_transforms_list_not_empty(self): + """QEffAutoPeftModelForCausalLM._onnx_transforms must not be empty.""" + from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM + + assert len(QEffAutoPeftModelForCausalLM._onnx_transforms) > 0 + + def test_peft_pytorch_transforms_list_not_empty(self): + """QEffAutoPeftModelForCausalLM._pytorch_transforms must not be empty.""" + from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM + + assert len(QEffAutoPeftModelForCausalLM._pytorch_transforms) > 0 + + def test_peft_model_has_export_method(self): + """QEffAutoPeftModelForCausalLM must have an export() method.""" + from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM + + assert hasattr(QEffAutoPeftModelForCausalLM, "export") + assert callable(QEffAutoPeftModelForCausalLM.export) + + def test_peft_model_has_compile_method(self): + """QEffAutoPeftModelForCausalLM must have a compile() method.""" + from QEfficient.peft.auto import QEffAutoPeftModelForCausalLM + + assert hasattr(QEffAutoPeftModelForCausalLM, "compile") + assert callable(QEffAutoPeftModelForCausalLM.compile) diff --git a/tests/unit_test/transforms/test_quantization_transforms.py b/tests/unit_test/transforms/test_quantization_transforms.py new file mode 100644 index 000000000..b7fa03c1d --- /dev/null +++ b/tests/unit_test/transforms/test_quantization_transforms.py @@ -0,0 +1,357 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Tests for quantization transforms and quantizer auto-detection in QEfficient. + +Tests verify: + - AwqToMatmulNbitsTransform: importable, has _match_class, has mutate method + - GPTQToMatmulNbitsTransform: importable, has _match_class, has mutate method + - FP8DeQuantLinearToLinearTransform: importable, has _match_class, has mutate method + - Mxfp4GptOssExpertDequantizeTransform: importable, has _match_class, has mutate method + - QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING: contains all expected quantization types + - QEFF_AUTO_QUANTIZER_MAPPING: contains all expected quantizer types + - with_replaced_quantizers: replaces and restores transformers quantizers correctly + - QEFFAutoModelForCausalLM._pytorch_transforms includes quantization transforms + +All tests run on CPU only, no quantized model downloads required. +""" + + +# --------------------------------------------------------------------------- +# Tests: Quantization Transform Importability and Structure +# --------------------------------------------------------------------------- + + +class TestQuantizationTransformImportability: + """All quantization transforms must be importable and have correct structure.""" + + def test_awq_transform_importable(self): + from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform + + assert AwqToMatmulNbitsTransform is not None + + def test_gptq_transform_importable(self): + from QEfficient.transformers.quantizers.quant_transforms import GPTQToMatmulNbitsTransform + + assert GPTQToMatmulNbitsTransform is not None + + def test_fp8_transform_importable(self): + from QEfficient.transformers.quantizers.quant_transforms import FP8DeQuantLinearToLinearTransform + + assert FP8DeQuantLinearToLinearTransform is not None + + def test_mxfp4_transform_importable(self): + from QEfficient.transformers.quantizers.quant_transforms import Mxfp4GptOssExpertDequantizeTransform + + assert Mxfp4GptOssExpertDequantizeTransform is not None + + def test_awq_transform_has_match_class(self): + from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform + + assert hasattr(AwqToMatmulNbitsTransform, "_match_class") + + def test_gptq_transform_has_match_class(self): + from QEfficient.transformers.quantizers.quant_transforms import GPTQToMatmulNbitsTransform + + assert hasattr(GPTQToMatmulNbitsTransform, "_match_class") + + def test_fp8_transform_has_match_class(self): + from QEfficient.transformers.quantizers.quant_transforms import FP8DeQuantLinearToLinearTransform + + assert hasattr(FP8DeQuantLinearToLinearTransform, "_match_class") + + def test_mxfp4_transform_has_match_class(self): + from QEfficient.transformers.quantizers.quant_transforms import Mxfp4GptOssExpertDequantizeTransform + + assert hasattr(Mxfp4GptOssExpertDequantizeTransform, "_match_class") + + def test_awq_match_class_is_wqlinear_gemm(self): + from QEfficient.transformers.quantizers.awq import WQLinear_GEMM + from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform + + assert AwqToMatmulNbitsTransform._match_class is WQLinear_GEMM + + def test_gptq_match_class_is_quantlinear_gptq(self): + from QEfficient.transformers.quantizers.gptq import QuantLinearGPTQ + from QEfficient.transformers.quantizers.quant_transforms import GPTQToMatmulNbitsTransform + + assert GPTQToMatmulNbitsTransform._match_class is QuantLinearGPTQ + + def test_fp8_match_class_is_fp8_dequant_linear(self): + from QEfficient.transformers.quantizers.quant_transforms import FP8DeQuantLinearToLinearTransform + from QEfficient.transformers.quantizers.quantizer_compressed_tensors import FP8DeQuantLinear + + assert FP8DeQuantLinearToLinearTransform._match_class is FP8DeQuantLinear + + def test_all_transforms_have_mutate_classmethod(self): + from QEfficient.transformers.quantizers.quant_transforms import ( + AwqToMatmulNbitsTransform, + FP8DeQuantLinearToLinearTransform, + GPTQToMatmulNbitsTransform, + Mxfp4GptOssExpertDequantizeTransform, + ) + + for cls in [ + AwqToMatmulNbitsTransform, + GPTQToMatmulNbitsTransform, + FP8DeQuantLinearToLinearTransform, + Mxfp4GptOssExpertDequantizeTransform, + ]: + assert hasattr(cls, "mutate"), f"{cls.__name__} missing mutate method" + assert callable(cls.mutate), f"{cls.__name__}.mutate is not callable" + + def test_all_transforms_are_subclasses_of_module_mutator(self): + from QEfficient.base.pytorch_transforms import ModuleMutatorTransform + from QEfficient.transformers.quantizers.quant_transforms import ( + AwqToMatmulNbitsTransform, + FP8DeQuantLinearToLinearTransform, + GPTQToMatmulNbitsTransform, + Mxfp4GptOssExpertDequantizeTransform, + ) + + for cls in [ + AwqToMatmulNbitsTransform, + GPTQToMatmulNbitsTransform, + FP8DeQuantLinearToLinearTransform, + Mxfp4GptOssExpertDequantizeTransform, + ]: + assert issubclass(cls, ModuleMutatorTransform), ( + f"{cls.__name__} must be a subclass of ModuleMutatorTransform" + ) + + +# --------------------------------------------------------------------------- +# Tests: QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING +# --------------------------------------------------------------------------- + + +class TestQEffAutoQuantizationConfigMapping: + """QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING must contain all expected quantization types.""" + + def test_mapping_exists_and_is_dict(self): + from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING + + assert isinstance(QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, dict) + + def test_contains_awq(self): + from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING + + assert "awq" in QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING + + def test_contains_gptq(self): + from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING + + assert "gptq" in QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING + + def test_contains_compressed_tensors(self): + from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING + + assert "compressed-tensors" in QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING + + def test_awq_config_is_qeff_awq_config(self): + from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING + from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqConfig + + assert QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING["awq"] is QEffAwqConfig + + def test_gptq_config_is_qeff_gptq_config(self): + from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING + from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQConfig + + assert QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING["gptq"] is QEffGPTQConfig + + def test_all_values_are_classes(self): + from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING + + for key, val in QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.items(): + assert isinstance(val, type), f"Expected class for key '{key}', got {type(val)}" + + +# --------------------------------------------------------------------------- +# Tests: QEFF_AUTO_QUANTIZER_MAPPING +# --------------------------------------------------------------------------- + + +class TestQEffAutoQuantizerMapping: + """QEFF_AUTO_QUANTIZER_MAPPING must contain all expected quantizer types.""" + + def test_mapping_exists_and_is_dict(self): + from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZER_MAPPING + + assert isinstance(QEFF_AUTO_QUANTIZER_MAPPING, dict) + + def test_contains_awq(self): + from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZER_MAPPING + + assert "awq" in QEFF_AUTO_QUANTIZER_MAPPING + + def test_contains_gptq(self): + from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZER_MAPPING + + assert "gptq" in QEFF_AUTO_QUANTIZER_MAPPING + + def test_awq_quantizer_is_qeff_awq_quantizer(self): + from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZER_MAPPING + from QEfficient.transformers.quantizers.quantizer_awq import QEffAwqQuantizer + + assert QEFF_AUTO_QUANTIZER_MAPPING["awq"] is QEffAwqQuantizer + + def test_gptq_quantizer_is_qeff_gptq_quantizer(self): + from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZER_MAPPING + from QEfficient.transformers.quantizers.quantizer_gptq import QEffGPTQQuantizer + + assert QEFF_AUTO_QUANTIZER_MAPPING["gptq"] is QEffGPTQQuantizer + + def test_all_values_are_classes(self): + from QEfficient.transformers.quantizers.auto import QEFF_AUTO_QUANTIZER_MAPPING + + for key, val in QEFF_AUTO_QUANTIZER_MAPPING.items(): + assert isinstance(val, type), f"Expected class for key '{key}', got {type(val)}" + + +# --------------------------------------------------------------------------- +# Tests: with_replaced_quantizers decorator +# --------------------------------------------------------------------------- + + +class TestWithReplacedQuantizers: + """with_replaced_quantizers must replace and restore transformers quantizers correctly.""" + + def test_with_replaced_quantizers_is_callable(self): + from QEfficient.transformers.quantizers.auto import with_replaced_quantizers + + assert callable(with_replaced_quantizers) + + def test_with_replaced_quantizers_wraps_function(self): + """Inside the wrapper, AUTO_QUANTIZATION_CONFIG_MAPPING must have QEff configs.""" + from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING + + from QEfficient.transformers.quantizers.auto import ( + QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING, + with_replaced_quantizers, + ) + + call_log = [] + + @with_replaced_quantizers + def dummy_func(): + for k, v in QEFF_AUTO_QUANTIZATION_CONFIG_MAPPING.items(): + assert AUTO_QUANTIZATION_CONFIG_MAPPING.get(k) is v, ( + f"Key '{k}' not replaced: expected {v}, got {AUTO_QUANTIZATION_CONFIG_MAPPING.get(k)}" + ) + call_log.append("called") + return "result" + + result = dummy_func() + assert result == "result" + assert call_log == ["called"] + + def test_with_replaced_quantizers_restores_after_call(self): + """After the wrapped function returns, original quantizers must be restored.""" + from transformers.quantizers.auto import AUTO_QUANTIZATION_CONFIG_MAPPING + + from QEfficient.transformers.quantizers.auto import with_replaced_quantizers + + # Capture original values before wrapping + original_awq = AUTO_QUANTIZATION_CONFIG_MAPPING.get("awq") + + @with_replaced_quantizers + def dummy_func(): + pass + + dummy_func() + + # After call, original must be restored + assert AUTO_QUANTIZATION_CONFIG_MAPPING.get("awq") is original_awq, ( + "with_replaced_quantizers must restore original 'awq' config after call" + ) + + def test_with_replaced_quantizers_preserves_return_value(self): + from QEfficient.transformers.quantizers.auto import with_replaced_quantizers + + @with_replaced_quantizers + def func_with_return(): + return {"key": "value", "num": 42} + + result = func_with_return() + assert result == {"key": "value", "num": 42} + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForCausalLM quantization transform integration +# --------------------------------------------------------------------------- + + +class TestQEFFAutoModelQuantizationIntegration: + """QEFFAutoModelForCausalLM must include quantization transforms in its pipeline.""" + + def test_pytorch_transforms_include_awq_transform(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform + + assert AwqToMatmulNbitsTransform in QEFFAutoModelForCausalLM._pytorch_transforms, ( + "AwqToMatmulNbitsTransform not in QEFFAutoModelForCausalLM._pytorch_transforms" + ) + + def test_pytorch_transforms_include_gptq_transform(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + from QEfficient.transformers.quantizers.quant_transforms import GPTQToMatmulNbitsTransform + + assert GPTQToMatmulNbitsTransform in QEFFAutoModelForCausalLM._pytorch_transforms, ( + "GPTQToMatmulNbitsTransform not in QEFFAutoModelForCausalLM._pytorch_transforms" + ) + + def test_pytorch_transforms_include_fp8_transform(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + from QEfficient.transformers.quantizers.quant_transforms import FP8DeQuantLinearToLinearTransform + + assert FP8DeQuantLinearToLinearTransform in QEFFAutoModelForCausalLM._pytorch_transforms, ( + "FP8DeQuantLinearToLinearTransform not in QEFFAutoModelForCausalLM._pytorch_transforms" + ) + + def test_quantization_transforms_come_before_kv_cache_transform(self): + """Quantization transforms must be applied before KVCacheTransform.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + from QEfficient.transformers.models.pytorch_transforms import KVCacheTransform + from QEfficient.transformers.quantizers.quant_transforms import AwqToMatmulNbitsTransform + + transforms = QEFFAutoModelForCausalLM._pytorch_transforms + awq_idx = next((i for i, t in enumerate(transforms) if t is AwqToMatmulNbitsTransform), None) + kv_idx = next((i for i, t in enumerate(transforms) if t is KVCacheTransform), None) + assert awq_idx is not None, "AwqToMatmulNbitsTransform not found in _pytorch_transforms" + assert kv_idx is not None, "KVCacheTransform not found in _pytorch_transforms" + assert awq_idx < kv_idx, ( + f"AwqToMatmulNbitsTransform (idx={awq_idx}) must come before KVCacheTransform (idx={kv_idx})" + ) + + def test_non_quantized_model_not_affected_by_quant_transforms(self): + """Applying quantization transforms to a non-quantized model must not change it.""" + import torch + from transformers import GPT2Config, GPT2LMHeadModel + + from QEfficient.transformers.quantizers.quant_transforms import ( + AwqToMatmulNbitsTransform, + GPTQToMatmulNbitsTransform, + ) + + cfg = GPT2Config(n_layer=1, n_head=2, n_embd=64, vocab_size=500, n_positions=32, n_ctx=32) + model = GPT2LMHeadModel(cfg).eval() + + # Apply AWQ transform - should not change a non-quantized model + model_awq, applied_awq = AwqToMatmulNbitsTransform.apply(model) + assert not applied_awq, "AwqToMatmulNbitsTransform must not apply to non-quantized model" + + # Apply GPTQ transform - should not change a non-quantized model + model_gptq, applied_gptq = GPTQToMatmulNbitsTransform.apply(model) + assert not applied_gptq, "GPTQToMatmulNbitsTransform must not apply to non-quantized model" + + # Model output must be unchanged + input_ids = torch.randint(0, 500, (1, 8)) + with torch.no_grad(): + original_logits = model(input_ids=input_ids).logits + awq_logits = model_awq(input_ids=input_ids).logits + assert torch.allclose(original_logits, awq_logits), "AWQ transform must not change non-quantized model output" diff --git a/tests/unit_test/transforms/test_speculative_decoding.py b/tests/unit_test/transforms/test_speculative_decoding.py new file mode 100644 index 000000000..cdffb7c46 --- /dev/null +++ b/tests/unit_test/transforms/test_speculative_decoding.py @@ -0,0 +1,581 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Tests for Speculative Decoding (SpDTransform) in QEfficient. + +Tests verify: + - SpDTransform.apply() with speculative_model_type="target" attaches tlm_forward + - SpDTransform._module_mapping contains expected model classes + - SpDTransform raises ValueError for invalid speculative_model_type + - SpDTransform raises NotImplementedError for unsupported model class + - QEFFAutoModelForCausalLM has check_and_get_num_speculative_tokens method + - QEFFAutoModelForCausalLM has build_prefill_specialization / build_decode_specialization + - is_tlm flag is set correctly on the wrapper + +All tests run on CPU only. +""" + +import pytest +import torch +from transformers import LlamaConfig, LlamaForCausalLM + +from QEfficient.transformers.models.pytorch_transforms import KVCacheTransform, SpDTransform + +VOCAB_SIZE = 500 +SEQ_LEN = 8 +CTX_LEN = 32 + + +def make_tiny_llama(): + cfg = LlamaConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return LlamaForCausalLM(cfg).eval(), cfg + + +def make_kv_transformed_llama(): + model, cfg = make_tiny_llama() + transformed, _ = KVCacheTransform.apply(model) + return transformed, cfg + + +# --------------------------------------------------------------------------- +# Tests: SpDTransform module mapping and structure +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestSpDTransformStructure: + """SpDTransform must have correct class-level structure.""" + + def test_spd_transform_importable(self): + from QEfficient.transformers.models.pytorch_transforms import SpDTransform + + assert SpDTransform is not None + + def test_module_mapping_is_set(self): + assert hasattr(SpDTransform, "_module_mapping") + assert len(SpDTransform._module_mapping) > 0 + + def test_module_mapping_contains_llama(self): + from QEfficient.transformers.models.llama.modeling_llama import QEffLlamaForCausalLM + + assert QEffLlamaForCausalLM in SpDTransform._module_mapping + + def test_module_mapping_contains_qwen2(self): + from QEfficient.transformers.models.qwen2.modeling_qwen2 import QEffQwen2ForCausalLM + + assert QEffQwen2ForCausalLM in SpDTransform._module_mapping + + def test_apply_classmethod_exists(self): + assert hasattr(SpDTransform, "apply") + assert callable(SpDTransform.apply) + + +# --------------------------------------------------------------------------- +# Tests: SpDTransform no-op paths (already tested in test_transform_accuracy.py, +# but included here for completeness) +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestSpDTransformNoOpPaths: + """SpDTransform must not apply when qaic_config is None or missing key.""" + + def test_no_transform_when_qaic_config_is_none(self): + model, _ = make_kv_transformed_llama() + _, applied = SpDTransform.apply(model, qaic_config=None) + assert not applied + + def test_no_transform_when_speculative_model_type_missing(self): + model, _ = make_kv_transformed_llama() + _, applied = SpDTransform.apply(model, qaic_config={}) + assert not applied + + def test_invalid_speculative_model_type_raises_value_error(self): + model, _ = make_kv_transformed_llama() + with pytest.raises(ValueError): + SpDTransform.apply(model, qaic_config={"speculative_model_type": "invalid_xyz_abc"}) + + def test_unsupported_model_class_raises_not_implemented(self): + import torch.nn as nn + + class UnsupportedModel(nn.Module): + def forward(self, x): + return x + + with pytest.raises(NotImplementedError): + SpDTransform.apply( + UnsupportedModel(), + qaic_config={"speculative_model_type": "target"}, + ) + + +# --------------------------------------------------------------------------- +# Tests: SpDTransform actual apply (TLM path) +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestSpDTransformTLMApply: + """SpDTransform with speculative_model_type='target' must attach tlm_forward.""" + + def test_spd_transform_applies_to_llama_with_target_type(self): + """SpDTransform must apply successfully to QEffLlamaForCausalLM with target type.""" + model, _ = make_kv_transformed_llama() + transformed, applied = SpDTransform.apply(model, qaic_config={"speculative_model_type": "target"}) + assert applied, "SpDTransform must apply when speculative_model_type='target'" + + def test_spd_transform_forward_is_replaced(self): + """After SpDTransform, model.forward must be replaced with a SpD-specific forward.""" + model, _ = make_kv_transformed_llama() + original_forward = model.forward + transformed, applied = SpDTransform.apply(model, qaic_config={"speculative_model_type": "target"}) + assert applied + assert hasattr(transformed, "forward") + # The forward must have been replaced (different from original) + assert transformed.forward is not original_forward, ( + "SpDTransform must replace model.forward with a SpD-specific forward" + ) + + def test_spd_transform_returns_model_instance(self): + """SpDTransform must return the same model instance (in-place modification).""" + model, _ = make_kv_transformed_llama() + transformed, applied = SpDTransform.apply(model, qaic_config={"speculative_model_type": "target"}) + assert applied + assert transformed is model, "SpDTransform must modify model in-place" + + def test_spd_transformed_model_is_still_eval_mode(self): + """SpDTransform must not change the model's training mode.""" + model, _ = make_kv_transformed_llama() + assert not model.training + transformed, _ = SpDTransform.apply(model, qaic_config={"speculative_model_type": "target"}) + assert not transformed.training, "SpDTransform must not change model to training mode" + + def test_spd_transform_model_still_has_parameters(self): + """After SpDTransform, model must still have its parameters.""" + model, _ = make_kv_transformed_llama() + param_count_before = sum(p.numel() for p in model.parameters()) + transformed, _ = SpDTransform.apply(model, qaic_config={"speculative_model_type": "target"}) + param_count_after = sum(p.numel() for p in transformed.parameters()) + assert param_count_before == param_count_after, ( + f"SpDTransform changed parameter count: {param_count_before} → {param_count_after}" + ) + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForCausalLM SpD-related methods +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestQEFFAutoModelSpDMethods: + """QEFFAutoModelForCausalLM must have SpD-related methods.""" + + def test_has_check_and_get_num_speculative_tokens(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + assert hasattr(QEFFAutoModelForCausalLM, "check_and_get_num_speculative_tokens") + assert callable(QEFFAutoModelForCausalLM.check_and_get_num_speculative_tokens) + + def test_has_build_prefill_specialization(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + assert hasattr(QEFFAutoModelForCausalLM, "build_prefill_specialization") + assert callable(QEFFAutoModelForCausalLM.build_prefill_specialization) + + def test_has_build_decode_specialization(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + assert hasattr(QEFFAutoModelForCausalLM, "build_decode_specialization") + assert callable(QEFFAutoModelForCausalLM.build_decode_specialization) + + def test_has_is_tlm_property(self): + """QEFFAutoModelForCausalLM instances must expose is_tlm.""" + from transformers import GPT2Config, GPT2LMHeadModel + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + cfg = GPT2Config(n_layer=1, n_head=2, n_embd=64, vocab_size=500, n_positions=32, n_ctx=32) + model = GPT2LMHeadModel(cfg) + qeff = QEFFAutoModelForCausalLM(model) + assert hasattr(qeff, "is_tlm"), "QEFFAutoModelForCausalLM instance must have is_tlm attribute" + + def test_is_tlm_false_by_default(self): + """Without SpD config, is_tlm must be False.""" + from transformers import GPT2Config, GPT2LMHeadModel + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + cfg = GPT2Config(n_layer=1, n_head=2, n_embd=64, vocab_size=500, n_positions=32, n_ctx=32) + model = GPT2LMHeadModel(cfg) + qeff = QEFFAutoModelForCausalLM(model) + assert qeff.is_tlm is False, "is_tlm must be False when no SpD config is provided" + + def test_check_and_get_num_speculative_tokens_returns_none_for_non_tlm(self): + """For a non-TLM model, check_and_get_num_speculative_tokens must not raise.""" + from transformers import GPT2Config, GPT2LMHeadModel + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + cfg = GPT2Config(n_layer=1, n_head=2, n_embd=64, vocab_size=500, n_positions=32, n_ctx=32) + model = GPT2LMHeadModel(cfg) + qeff = QEFFAutoModelForCausalLM(model) + # For non-TLM, is_tlm=False; method accepts num_speculative_tokens and prefill_seq_len + result = qeff.check_and_get_num_speculative_tokens(num_speculative_tokens=None, prefill_seq_len=1) + assert result is None, f"check_and_get_num_speculative_tokens must return None for non-TLM, got {result}" + + def test_build_prefill_specialization_returns_dict(self): + """build_prefill_specialization must return a dict-like object.""" + from transformers import GPT2Config, GPT2LMHeadModel + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + cfg = GPT2Config(n_layer=1, n_head=2, n_embd=64, vocab_size=500, n_positions=32, n_ctx=32) + model = GPT2LMHeadModel(cfg) + qeff = QEFFAutoModelForCausalLM(model) + result = qeff.build_prefill_specialization(prefill_seq_len=8, ctx_len=32, batch_size=1, full_batch_size=None) + assert isinstance(result, dict), f"build_prefill_specialization must return dict, got {type(result)}" + + def test_build_decode_specialization_returns_dict(self): + """build_decode_specialization must return a dict-like object.""" + from transformers import GPT2Config, GPT2LMHeadModel + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + cfg = GPT2Config(n_layer=1, n_head=2, n_embd=64, vocab_size=500, n_positions=32, n_ctx=32) + model = GPT2LMHeadModel(cfg) + qeff = QEFFAutoModelForCausalLM(model) + result = qeff.build_decode_specialization(ctx_len=32, batch_size=1, full_batch_size=None) + assert isinstance(result, dict), f"build_decode_specialization must return dict, got {type(result)}" + + +# --------------------------------------------------------------------------- +# Tests: TLM forward execution +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestTLMForwardExecution: + """After SpDTransform, the replaced tlm_forward must produce correct outputs.""" + + def _make_tlm_inputs(self, batch=1, num_spec_tokens=3, n_layers=2, n_kv=2, head_dim=32): + """Create inputs for TLM forward with pre-allocated zero KV cache.""" + seq_len = num_spec_tokens + 1 + input_ids = torch.randint(0, VOCAB_SIZE, (batch, seq_len)) + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1) + past_key_values = tuple( + ( + torch.zeros(batch, n_kv, CTX_LEN, head_dim, dtype=torch.float32), + torch.zeros(batch, n_kv, CTX_LEN, head_dim, dtype=torch.float32), + ) + for _ in range(n_layers) + ) + return input_ids, position_ids, past_key_values + + def test_tlm_forward_returns_logits(self): + """tlm_forward must return an object with logits attribute.""" + model, cfg = make_kv_transformed_llama() + transformed, applied = SpDTransform.apply(model, qaic_config={"speculative_model_type": "target"}) + assert applied + + batch, num_spec_tokens = 1, 3 + # n_kv=2, head_dim=64//2=32 for tiny llama + # num_logits_to_keep must be a tensor (as expected by spd_transform_forward) + input_ids, position_ids, past_kv = self._make_tlm_inputs( + batch, num_spec_tokens, n_layers=2, n_kv=2, head_dim=32 + ) + num_logits_tensor = torch.tensor([num_spec_tokens], dtype=torch.int64) + + with torch.no_grad(): + output = transformed( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_kv, + num_logits_to_keep=num_logits_tensor, + ) + assert hasattr(output, "logits"), "TLM forward must return output with logits" + + def test_tlm_forward_logits_are_finite(self): + """tlm_forward logits must be finite (no NaN/Inf).""" + model, cfg = make_kv_transformed_llama() + transformed, applied = SpDTransform.apply(model, qaic_config={"speculative_model_type": "target"}) + assert applied + + batch, num_spec_tokens = 1, 3 + input_ids, position_ids, past_kv = self._make_tlm_inputs( + batch, num_spec_tokens, n_layers=2, n_kv=2, head_dim=32 + ) + num_logits_tensor = torch.tensor([num_spec_tokens], dtype=torch.int64) + + with torch.no_grad(): + output = transformed( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_kv, + num_logits_to_keep=num_logits_tensor, + ) + assert torch.isfinite(output.logits).all(), "TLM logits must be finite" + + def test_tlm_forward_logits_shape_is_batch_x_kept_x_vocab(self): + """tlm_forward logits shape must be [batch, num_logits_to_keep, vocab_size]. + num_logits_to_keep is a 1D tensor of shape [1] containing the count, + so the output has shape[1] == num_logits_to_keep.shape[0] == 1.""" + model, cfg = make_kv_transformed_llama() + transformed, applied = SpDTransform.apply(model, qaic_config={"speculative_model_type": "target"}) + assert applied + + batch, num_spec_tokens = 1, 3 + input_ids, position_ids, past_kv = self._make_tlm_inputs( + batch, num_spec_tokens, n_layers=2, n_kv=2, head_dim=32 + ) + # num_logits_to_keep is a 1D tensor; shape[0] determines how many logits are kept + num_logits_tensor = torch.tensor([num_spec_tokens], dtype=torch.int64) + + with torch.no_grad(): + output = transformed( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_kv, + num_logits_to_keep=num_logits_tensor, + ) + # batch dimension must match + assert output.logits.shape[0] == batch + # vocab dimension must match + assert output.logits.shape[-1] == VOCAB_SIZE + # logits must be 3D: [batch, seq, vocab] + assert output.logits.ndim == 3 + + def test_tlm_forward_greedy_tokens_in_valid_range(self): + """Greedy tokens from tlm_forward must be in [0, vocab_size).""" + model, cfg = make_kv_transformed_llama() + transformed, applied = SpDTransform.apply(model, qaic_config={"speculative_model_type": "target"}) + assert applied + + batch, num_spec_tokens = 1, 3 + input_ids, position_ids, past_kv = self._make_tlm_inputs( + batch, num_spec_tokens, n_layers=2, n_kv=2, head_dim=32 + ) + num_logits_tensor = torch.tensor([num_spec_tokens], dtype=torch.int64) + + with torch.no_grad(): + output = transformed( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_kv, + num_logits_to_keep=num_logits_tensor, + ) + greedy_tokens = output.logits.argmax(dim=-1) + assert (greedy_tokens >= 0).all() + assert (greedy_tokens < VOCAB_SIZE).all() + + +# --------------------------------------------------------------------------- +# Tests: SpDTransform for Qwen2 +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestSpDTransformQwen2: + """SpDTransform must apply correctly to Qwen2 models.""" + + def _make_kv_transformed_qwen2(self): + from transformers import Qwen2Config, Qwen2ForCausalLM + + from QEfficient.transformers.models.pytorch_transforms import KVCacheTransform + + cfg = Qwen2Config( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + model = Qwen2ForCausalLM(cfg).eval() + transformed, _ = KVCacheTransform.apply(model) + return transformed, cfg + + def test_spd_transform_applies_to_qwen2_with_target_type(self): + """SpDTransform must apply successfully to QEffQwen2ForCausalLM.""" + model, _ = self._make_kv_transformed_qwen2() + transformed, applied = SpDTransform.apply(model, qaic_config={"speculative_model_type": "target"}) + assert applied, "SpDTransform must apply to Qwen2 with target type" + + def test_spd_transform_qwen2_forward_is_replaced(self): + """After SpDTransform, Qwen2 model.forward must be replaced.""" + model, _ = self._make_kv_transformed_qwen2() + original_forward = model.forward + transformed, applied = SpDTransform.apply(model, qaic_config={"speculative_model_type": "target"}) + assert applied + assert transformed.forward is not original_forward + + def test_spd_transform_qwen2_produces_finite_logits(self): + """After SpDTransform, Qwen2 forward must produce finite logits.""" + + model, _ = self._make_kv_transformed_qwen2() + transformed, applied = SpDTransform.apply(model, qaic_config={"speculative_model_type": "target"}) + assert applied + + batch, num_spec_tokens = 1, 2 + seq_len = num_spec_tokens + 1 + input_ids = torch.randint(0, VOCAB_SIZE, (batch, seq_len)) + position_ids = torch.arange(seq_len).unsqueeze(0).expand(batch, -1) + # Use tuple-based KV cache (n_kv=2, head_dim=64//2=32) + past_kv = tuple( + ( + torch.zeros(batch, 2, CTX_LEN, 32, dtype=torch.float32), + torch.zeros(batch, 2, CTX_LEN, 32, dtype=torch.float32), + ) + for _ in range(2) + ) + num_logits_tensor = torch.tensor([num_spec_tokens], dtype=torch.int64) + + with torch.no_grad(): + output = transformed( + input_ids=input_ids, + position_ids=position_ids, + past_key_values=past_kv, + num_logits_to_keep=num_logits_tensor, + ) + assert torch.isfinite(output.logits).all() + + +# --------------------------------------------------------------------------- +# Tests: post_processing.py registry +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestPostProcessingRegistry: + """post_processing.model_type_registry must contain expected model types.""" + + def test_model_type_registry_is_not_empty(self): + """model_type_registry must not be empty.""" + from QEfficient.transformers.post_processing import model_type_registry + + assert len(model_type_registry) > 0 + + def test_model_type_registry_contains_turbo(self): + """model_type_registry must contain 'turbo' (the SpD post-processing type).""" + from QEfficient.transformers.post_processing import model_type_registry + + assert "turbo" in model_type_registry + + def test_model_type_registry_keys_are_strings(self): + """All keys in model_type_registry must be strings.""" + from QEfficient.transformers.post_processing import model_type_registry + + for key in model_type_registry: + assert isinstance(key, str), f"Registry key must be string, got {type(key)}" + + def test_model_type_registry_values_are_callable(self): + """All values in model_type_registry must be callable.""" + from QEfficient.transformers.post_processing import model_type_registry + + for model_type, handler in model_type_registry.items(): + assert callable(handler), f"Handler for '{model_type}' must be callable" + + +# --------------------------------------------------------------------------- +# Tests: SpD ONNX structure (GAP I) +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestSpDONNXStructure: + """SpD-related ONNX structure tests — verify num_logits_to_keep input and build_and_attach_mlp.""" + + def test_build_and_attach_mlp_importable(self): + """build_and_attach_mlp must be importable from post_processing.""" + from QEfficient.transformers.post_processing import build_and_attach_mlp + + assert build_and_attach_mlp is not None + + def test_build_and_attach_mlp_is_callable(self): + """build_and_attach_mlp must be callable.""" + from QEfficient.transformers.post_processing import build_and_attach_mlp + + assert callable(build_and_attach_mlp) + + def test_build_and_attach_mlp_accepts_model_parameter(self): + """build_and_attach_mlp must accept 'model' as first parameter.""" + import inspect + + from QEfficient.transformers.post_processing import build_and_attach_mlp + + sig = inspect.signature(build_and_attach_mlp) + assert "model" in sig.parameters + + def test_build_and_attach_mlp_accepts_speculative_model_type(self): + """build_and_attach_mlp must accept 'speculative_model_type' parameter.""" + import inspect + + from QEfficient.transformers.post_processing import build_and_attach_mlp + + sig = inspect.signature(build_and_attach_mlp) + assert "speculative_model_type" in sig.parameters + + def test_model_type_registry_has_turbo(self): + """model_type_registry must contain 'turbo' key.""" + from QEfficient.transformers.post_processing import model_type_registry + + assert "turbo" in model_type_registry + + def test_build_and_attach_turbo_importable(self): + """build_and_attach_turbo must be importable from spd.turbo.""" + from QEfficient.transformers.spd.turbo import build_and_attach_turbo + + assert build_and_attach_turbo is not None + + @pytest.mark.onnx + @pytest.mark.slow + def test_tlm_onnx_has_num_logits_to_keep_input(self, tmp_export_dir): + """TLM ONNX export must include 'num_logits_to_keep' as an input.""" + import onnx + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + model, cfg = make_tiny_llama() + qeff_model = QEFFAutoModelForCausalLM( + model, + qaic_config={"speculative_model_type": "target"}, + ) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + + input_names = [inp.name for inp in onnx_model.graph.input] + assert "num_logits_to_keep" in input_names, ( + f"TLM ONNX must have 'num_logits_to_keep' input. Found: {input_names}" + ) + + @pytest.mark.onnx + @pytest.mark.slow + def test_tlm_onnx_logits_output_is_present(self, tmp_export_dir): + """TLM ONNX export must include 'logits' as an output.""" + import onnx + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + model, cfg = make_tiny_llama() + qeff_model = QEFFAutoModelForCausalLM( + model, + qaic_config={"speculative_model_type": "target"}, + ) + onnx_path = qeff_model.export(export_dir=str(tmp_export_dir)) + onnx_model = onnx.load(str(onnx_path)) + + output_names = [out.name for out in onnx_model.graph.output] + assert "logits" in output_names, f"TLM ONNX must have 'logits' output. Found: {output_names}" diff --git a/tests/unit_test/transforms/test_transform_accuracy.py b/tests/unit_test/transforms/test_transform_accuracy.py new file mode 100644 index 000000000..fed77f470 --- /dev/null +++ b/tests/unit_test/transforms/test_transform_accuracy.py @@ -0,0 +1,1652 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Accuracy tests for PyTorch transforms in QEfficient. + +Improvements over unit_v2: + - Expanded CustomOpsTransform coverage: Phi3, Gemma, Gemma2 + - Expanded KVCacheTransform coverage: Phi3, Gemma, Gemma2, Falcon + - Expanded combined transforms: Phi3, Gemma, Gemma2 + - SamplerTransform and SpDTransform behavior tests + +Tests verify that transforms: + 1. Replace the correct module types + 2. Do NOT change the model's numerical output (accuracy preservation) + 3. Work correctly in combination + +All tests run on CPU only, using tiny in-memory models. +""" + +import pytest +import torch +import torch.nn.functional as F +from transformers import ( + FalconConfig, + FalconForCausalLM, + Gemma2Config, + Gemma2ForCausalLM, + GemmaConfig, + GemmaForCausalLM, + GPT2Config, + GPT2LMHeadModel, + LlamaConfig, + LlamaForCausalLM, + MistralConfig, + MistralForCausalLM, + Phi3Config, + Phi3ForCausalLM, + Qwen2Config, + Qwen2ForCausalLM, +) + +from QEfficient.transformers.models.pytorch_transforms import ( + CustomOpsTransform, + KVCacheTransform, + PoolingTransform, + SamplerTransform, + SpDTransform, +) + +VOCAB_SIZE = 500 +SEQ_LEN = 8 +CTX_LEN = 32 + + +# --------------------------------------------------------------------------- +# Tiny model factories +# --------------------------------------------------------------------------- + + +def make_tiny_gpt2(): + cfg = GPT2Config(n_layer=1, n_head=2, n_embd=64, vocab_size=VOCAB_SIZE, n_positions=CTX_LEN, n_ctx=CTX_LEN) + return GPT2LMHeadModel(cfg).eval() + + +def make_tiny_llama(): + cfg = LlamaConfig( + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return LlamaForCausalLM(cfg).eval() + + +def make_tiny_mistral(): + cfg = MistralConfig( + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return MistralForCausalLM(cfg).eval() + + +def make_tiny_qwen2(): + cfg = Qwen2Config( + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + ) + return Qwen2ForCausalLM(cfg).eval() + + +def make_tiny_phi3(): + cfg = Phi3Config( + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + pad_token_id=0, + ) + return Phi3ForCausalLM(cfg).eval() + + +def make_tiny_gemma(): + cfg = GemmaConfig( + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + head_dim=32, + ) + return GemmaForCausalLM(cfg).eval() + + +def make_tiny_gemma2(): + cfg = Gemma2Config( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + head_dim=32, + sliding_window=CTX_LEN, + ) + return Gemma2ForCausalLM(cfg).eval() + + +def make_tiny_falcon(): + cfg = FalconConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + new_decoder_architecture=False, + multi_query=True, + ) + return FalconForCausalLM(cfg).eval() + + +# --------------------------------------------------------------------------- +# QEff input helpers +# --------------------------------------------------------------------------- + + +def _get_dims(config): + """Extract (n_layers, n_kv_heads, head_dim) from any model config.""" + if hasattr(config, "num_hidden_layers"): + n_layers = config.num_hidden_layers + n_attn = config.num_attention_heads + n_kv = getattr(config, "num_key_value_heads", n_attn) + head_dim = getattr(config, "head_dim", None) or (config.hidden_size // n_attn) + else: + # GPT2 + n_layers = config.n_layer + n_kv = config.n_head + head_dim = config.n_embd // config.n_head + return n_layers, n_kv, head_dim + + +def _make_qeff_inputs(input_ids, config, ctx_len=CTX_LEN): + """Build QEff-style inputs: input_ids + position_ids + zero-initialized past_key_values.""" + batch, seq = input_ids.shape + position_ids = torch.arange(seq).unsqueeze(0).expand(batch, -1) + n_layers, n_kv, head_dim = _get_dims(config) + past_key_values = tuple( + ( + torch.zeros(batch, n_kv, ctx_len, head_dim, dtype=torch.float32), + torch.zeros(batch, n_kv, ctx_len, head_dim, dtype=torch.float32), + ) + for _ in range(n_layers) + ) + return { + "input_ids": input_ids, + "position_ids": position_ids, + "past_key_values": past_key_values, + } + + +# --------------------------------------------------------------------------- +# Tests: CustomOpsTransform - module replacement +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestCustomOpsTransformReplacement: + """CustomOpsTransform must replace RMSNorm with CustomRMSNormAIC.""" + + def test_llama_rms_norm_replaced_with_custom_rms_norm(self): + from transformers.models.llama.modeling_llama import LlamaRMSNorm + + from QEfficient.customop import CustomRMSNormAIC + + model = make_tiny_llama() + assert any(isinstance(m, LlamaRMSNorm) for m in model.modules()) + + transformed, applied = CustomOpsTransform.apply(model) + assert applied + + for m in transformed.modules(): + if type(m) is LlamaRMSNorm: + pytest.fail("Found unreplaced LlamaRMSNorm after transform") + + assert any(isinstance(m, CustomRMSNormAIC) for m in transformed.modules()) + + def test_mistral_rms_norm_replaced(self): + from QEfficient.customop import CustomRMSNormAIC + + model = make_tiny_mistral() + transformed, applied = CustomOpsTransform.apply(model) + assert applied + assert any(isinstance(m, CustomRMSNormAIC) for m in transformed.modules()) + + def test_qwen2_rms_norm_replaced(self): + from QEfficient.customop import CustomRMSNormAIC + + model = make_tiny_qwen2() + transformed, applied = CustomOpsTransform.apply(model) + assert applied + assert any(isinstance(m, CustomRMSNormAIC) for m in transformed.modules()) + + def test_phi3_rms_norm_replaced(self): + from QEfficient.customop import CustomRMSNormAIC + + model = make_tiny_phi3() + transformed, applied = CustomOpsTransform.apply(model) + assert applied + assert any(isinstance(m, CustomRMSNormAIC) for m in transformed.modules()) + + def test_gemma_rms_norm_replaced(self): + from QEfficient.customop import GemmaCustomRMSNormAIC + + model = make_tiny_gemma() + transformed, applied = CustomOpsTransform.apply(model) + assert applied + assert any(isinstance(m, GemmaCustomRMSNormAIC) for m in transformed.modules()) + + def test_gemma2_rms_norm_replaced(self): + from QEfficient.customop import GemmaCustomRMSNormAIC + + model = make_tiny_gemma2() + transformed, applied = CustomOpsTransform.apply(model) + assert applied + assert any(isinstance(m, GemmaCustomRMSNormAIC) for m in transformed.modules()) + + def test_gpt2_not_transformed(self): + """GPT2 uses LayerNorm, not RMSNorm. CustomOpsTransform must not apply.""" + model = make_tiny_gpt2() + transformed, applied = CustomOpsTransform.apply(model) + assert not applied, "CustomOpsTransform must not apply to GPT2 (no RMSNorm)" + + def test_module_mapping_contains_expected_types(self): + from transformers.models.gemma.modeling_gemma import GemmaRMSNorm + from transformers.models.gemma2.modeling_gemma2 import Gemma2RMSNorm + from transformers.models.llama.modeling_llama import LlamaRMSNorm + from transformers.models.mistral.modeling_mistral import MistralRMSNorm + from transformers.models.phi3.modeling_phi3 import Phi3RMSNorm + from transformers.models.qwen2.modeling_qwen2 import Qwen2RMSNorm + + mapping = CustomOpsTransform._module_mapping + assert LlamaRMSNorm in mapping + assert MistralRMSNorm in mapping + assert Qwen2RMSNorm in mapping + assert Phi3RMSNorm in mapping + assert GemmaRMSNorm in mapping + assert Gemma2RMSNorm in mapping + + +# --------------------------------------------------------------------------- +# Tests: CustomOpsTransform - accuracy preservation +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestCustomOpsTransformAccuracy: + """ + CustomOpsTransform must NOT change the model's numerical output. + CustomRMSNormAIC must be numerically equivalent to LlamaRMSNorm. + """ + + def test_llama_output_unchanged_after_custom_ops_transform(self): + """Llama logits must be identical before and after CustomOpsTransform.""" + model = make_tiny_llama() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + before_logits = model(input_ids=input_ids).logits[:, -1, :] + + transformed, _ = CustomOpsTransform.apply(model) + with torch.no_grad(): + after_logits = transformed(input_ids=input_ids).logits[:, -1, :] + + max_diff = (before_logits - after_logits).abs().max().item() + assert max_diff < 1e-5, ( + f"CustomOpsTransform changed Llama output: max_diff={max_diff:.2e}. " + f"CustomRMSNormAIC must be numerically equivalent to LlamaRMSNorm." + ) + + def test_llama_greedy_token_unchanged_after_custom_ops_transform(self): + model = make_tiny_llama() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + before_token = model(input_ids=input_ids).logits[:, -1, :].argmax(-1).item() + + transformed, _ = CustomOpsTransform.apply(model) + with torch.no_grad(): + after_token = transformed(input_ids=input_ids).logits[:, -1, :].argmax(-1).item() + + assert before_token == after_token, ( + f"CustomOpsTransform changed greedy token: before={before_token}, after={after_token}" + ) + + def test_mistral_output_unchanged_after_custom_ops_transform(self): + model = make_tiny_mistral() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + before_logits = model(input_ids=input_ids).logits[:, -1, :] + + transformed, _ = CustomOpsTransform.apply(model) + with torch.no_grad(): + after_logits = transformed(input_ids=input_ids).logits[:, -1, :] + + max_diff = (before_logits - after_logits).abs().max().item() + assert max_diff < 1e-5, f"CustomOpsTransform changed Mistral output: max_diff={max_diff:.2e}" + + def test_phi3_output_unchanged_after_custom_ops_transform(self): + model = make_tiny_phi3() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + before_logits = model(input_ids=input_ids).logits[:, -1, :] + + transformed, _ = CustomOpsTransform.apply(model) + with torch.no_grad(): + after_logits = transformed(input_ids=input_ids).logits[:, -1, :] + + max_diff = (before_logits - after_logits).abs().max().item() + assert max_diff < 1e-5, f"CustomOpsTransform changed Phi3 output: max_diff={max_diff:.2e}" + + def test_gemma_output_unchanged_after_custom_ops_transform(self): + model = make_tiny_gemma() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + before_logits = model(input_ids=input_ids).logits[:, -1, :] + + transformed, _ = CustomOpsTransform.apply(model) + with torch.no_grad(): + after_logits = transformed(input_ids=input_ids).logits[:, -1, :] + + max_diff = (before_logits - after_logits).abs().max().item() + assert max_diff < 1e-5, f"CustomOpsTransform changed Gemma output: max_diff={max_diff:.2e}" + + def test_custom_rms_norm_forward_is_finite(self): + """CustomRMSNormAIC forward must produce finite outputs.""" + model = make_tiny_llama() + transformed, _ = CustomOpsTransform.apply(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + with torch.no_grad(): + out = transformed(input_ids=input_ids) + assert torch.isfinite(out.logits).all() + + +# --------------------------------------------------------------------------- +# Tests: KVCacheTransform - module replacement +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestKVCacheTransformReplacement: + """KVCacheTransform must replace attention layers with QEff variants.""" + + def test_gpt2_attention_replaced(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2Attention + + from QEfficient.transformers.models.gpt2.modeling_gpt2 import QEffGPT2Attention + + model = make_tiny_gpt2() + transformed, applied = KVCacheTransform.apply(model) + assert applied + + for m in transformed.modules(): + if isinstance(m, GPT2Attention): + assert isinstance(m, QEffGPT2Attention) + + def test_gpt2_lm_head_model_replaced(self): + from QEfficient.transformers.models.gpt2.modeling_gpt2 import QEffGPT2LMHeadModel + + model = make_tiny_gpt2() + transformed, _ = KVCacheTransform.apply(model) + assert isinstance(transformed, QEffGPT2LMHeadModel) + + def test_llama_attention_replaced(self): + from transformers.models.llama.modeling_llama import LlamaAttention + + from QEfficient.transformers.models.llama.modeling_llama import QEffLlamaAttention + + model = make_tiny_llama() + transformed, applied = KVCacheTransform.apply(model) + assert applied + + for m in transformed.modules(): + if isinstance(m, LlamaAttention): + assert isinstance(m, QEffLlamaAttention) + + def test_llama_for_causal_lm_replaced(self): + from QEfficient.transformers.models.llama.modeling_llama import QEffLlamaForCausalLM + + model = make_tiny_llama() + transformed, _ = KVCacheTransform.apply(model) + assert isinstance(transformed, QEffLlamaForCausalLM) + + def test_mistral_attention_replaced(self): + from transformers.models.mistral.modeling_mistral import MistralAttention + + from QEfficient.transformers.models.mistral.modeling_mistral import QEffMistralAttention + + model = make_tiny_mistral() + transformed, applied = KVCacheTransform.apply(model) + assert applied + + for m in transformed.modules(): + if isinstance(m, MistralAttention): + assert isinstance(m, QEffMistralAttention) + + def test_qwen2_attention_replaced(self): + from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention + + from QEfficient.transformers.models.qwen2.modeling_qwen2 import QEffQwen2Attention + + model = make_tiny_qwen2() + transformed, applied = KVCacheTransform.apply(model) + assert applied + + for m in transformed.modules(): + if isinstance(m, Qwen2Attention): + assert isinstance(m, QEffQwen2Attention) + + def test_phi3_attention_replaced(self): + from transformers.models.phi3.modeling_phi3 import Phi3Attention + + from QEfficient.transformers.models.phi3.modeling_phi3 import QEffPhi3Attention + + model = make_tiny_phi3() + transformed, applied = KVCacheTransform.apply(model) + assert applied + + for m in transformed.modules(): + if isinstance(m, Phi3Attention): + assert isinstance(m, QEffPhi3Attention) + + def test_gemma_attention_replaced(self): + from transformers.models.gemma.modeling_gemma import GemmaAttention + + from QEfficient.transformers.models.gemma.modeling_gemma import QEffGemmaAttention + + model = make_tiny_gemma() + transformed, applied = KVCacheTransform.apply(model) + assert applied + + for m in transformed.modules(): + if isinstance(m, GemmaAttention): + assert isinstance(m, QEffGemmaAttention) + + def test_falcon_attention_replaced(self): + from transformers.models.falcon.modeling_falcon import FalconAttention + + from QEfficient.transformers.models.falcon.modeling_falcon import QEffFalconAttention + + model = make_tiny_falcon() + transformed, applied = KVCacheTransform.apply(model) + assert applied + + for m in transformed.modules(): + if isinstance(m, FalconAttention): + assert isinstance(m, QEffFalconAttention) + + def test_module_mapping_covers_major_architectures(self): + from transformers.models.falcon.modeling_falcon import FalconForCausalLM + from transformers.models.gemma.modeling_gemma import GemmaForCausalLM + from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel + from transformers.models.llama.modeling_llama import LlamaForCausalLM + from transformers.models.mistral.modeling_mistral import MistralForCausalLM + from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM + from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM + from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM + + mapping = KVCacheTransform._module_mapping + assert GPT2LMHeadModel in mapping + assert LlamaForCausalLM in mapping + assert MistralForCausalLM in mapping + assert MixtralForCausalLM in mapping + assert Qwen2ForCausalLM in mapping + assert Phi3ForCausalLM in mapping + assert GemmaForCausalLM in mapping + assert FalconForCausalLM in mapping + + +# --------------------------------------------------------------------------- +# Tests: KVCacheTransform - accuracy preservation +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestKVCacheTransformAccuracy: + """ + KVCacheTransform must NOT change the model's greedy next token prediction. + This is the core regression test for the KV cache transform. + """ + + def _check_greedy_token_preserved(self, model, label): + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + before_token = model(input_ids=input_ids).logits[:, -1, :].argmax(-1).item() + + cfg = model.config + transformed, _ = KVCacheTransform.apply(model) + qeff_inputs = _make_qeff_inputs(input_ids, cfg) + + with torch.no_grad(): + after_out = transformed(**qeff_inputs) + after_token = after_out.logits[:, -1, :].argmax(-1).item() + + assert before_token == after_token, ( + f"[{label}] KVCacheTransform changed greedy token: " + f"before={before_token}, after={after_token}. " + f"KVCacheTransform must not change the model's prediction." + ) + + def test_gpt2_greedy_token_preserved_after_kv_transform(self): + self._check_greedy_token_preserved(make_tiny_gpt2(), "GPT2") + + def test_llama_greedy_token_preserved_after_kv_transform(self): + self._check_greedy_token_preserved(make_tiny_llama(), "Llama") + + def test_mistral_greedy_token_preserved_after_kv_transform(self): + self._check_greedy_token_preserved(make_tiny_mistral(), "Mistral") + + def test_qwen2_greedy_token_preserved_after_kv_transform(self): + self._check_greedy_token_preserved(make_tiny_qwen2(), "Qwen2") + + def test_phi3_greedy_token_preserved_after_kv_transform(self): + self._check_greedy_token_preserved(make_tiny_phi3(), "Phi3") + + def test_gemma_greedy_token_preserved_after_kv_transform(self): + self._check_greedy_token_preserved(make_tiny_gemma(), "Gemma") + + def test_gpt2_logits_numerically_close_after_kv_transform(self): + """GPT2 logits must be numerically close before and after KVCacheTransform.""" + model = make_tiny_gpt2() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + before_logits = model(input_ids=input_ids).logits[:, -1, :] + + cfg = model.config + transformed, _ = KVCacheTransform.apply(model) + qeff_inputs = _make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + after_logits = transformed(**qeff_inputs).logits[:, -1, :] + + hf_probs = F.softmax(before_logits, dim=-1) + qeff_probs = F.softmax(after_logits, dim=-1) + max_diff = (hf_probs - qeff_probs).abs().max().item() + assert max_diff < 1e-3, f"KVCacheTransform changed GPT2 probability distribution: max_diff={max_diff:.2e}" + + def test_llama_logits_numerically_close_after_kv_transform(self): + model = make_tiny_llama() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + before_logits = model(input_ids=input_ids).logits[:, -1, :] + + cfg = model.config + transformed, _ = KVCacheTransform.apply(model) + qeff_inputs = _make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + after_logits = transformed(**qeff_inputs).logits[:, -1, :] + + hf_probs = F.softmax(before_logits, dim=-1) + qeff_probs = F.softmax(after_logits, dim=-1) + max_diff = (hf_probs - qeff_probs).abs().max().item() + assert max_diff < 1e-3, f"KVCacheTransform changed Llama probability distribution: max_diff={max_diff:.2e}" + + def test_phi3_logits_numerically_close_after_kv_transform(self): + model = make_tiny_phi3() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + before_logits = model(input_ids=input_ids).logits[:, -1, :] + + cfg = model.config + transformed, _ = KVCacheTransform.apply(model) + qeff_inputs = _make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + after_logits = transformed(**qeff_inputs).logits[:, -1, :] + + hf_probs = F.softmax(before_logits, dim=-1) + qeff_probs = F.softmax(after_logits, dim=-1) + max_diff = (hf_probs - qeff_probs).abs().max().item() + assert max_diff < 1e-3, f"KVCacheTransform changed Phi3 probability distribution: max_diff={max_diff:.2e}" + + +# --------------------------------------------------------------------------- +# Tests: Combined transforms accuracy +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +@pytest.mark.accuracy +class TestCombinedTransformsAccuracy: + """ + Applying CustomOpsTransform + KVCacheTransform together must preserve accuracy. + This is the exact combination used by QEFFAutoModelForCausalLM. + """ + + def _check_combined_transforms(self, model, label): + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + original_token = model(input_ids=input_ids).logits[:, -1, :].argmax(-1).item() + + cfg = model.config + model, _ = CustomOpsTransform.apply(model) + model, _ = KVCacheTransform.apply(model) + + qeff_inputs = _make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + transformed_token = model(**qeff_inputs).logits[:, -1, :].argmax(-1).item() + + assert original_token == transformed_token, ( + f"[{label}] Combined transforms changed greedy token: " + f"original={original_token}, transformed={transformed_token}" + ) + + def test_llama_combined_transforms_preserve_greedy_token(self): + self._check_combined_transforms(make_tiny_llama(), "Llama") + + def test_mistral_combined_transforms_preserve_greedy_token(self): + self._check_combined_transforms(make_tiny_mistral(), "Mistral") + + def test_qwen2_combined_transforms_preserve_greedy_token(self): + self._check_combined_transforms(make_tiny_qwen2(), "Qwen2") + + def test_phi3_combined_transforms_preserve_greedy_token(self): + self._check_combined_transforms(make_tiny_phi3(), "Phi3") + + def test_gemma_combined_transforms_preserve_greedy_token(self): + self._check_combined_transforms(make_tiny_gemma(), "Gemma") + + def test_combined_transforms_produce_finite_outputs(self): + """Combined transforms must produce finite logits for all supported models.""" + for factory, label in [ + (make_tiny_llama, "Llama"), + (make_tiny_mistral, "Mistral"), + (make_tiny_qwen2, "Qwen2"), + (make_tiny_phi3, "Phi3"), + ]: + model = factory() + cfg = model.config + model, _ = CustomOpsTransform.apply(model) + model, _ = KVCacheTransform.apply(model) + + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + qeff_inputs = _make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + out = model(**qeff_inputs) + assert torch.isfinite(out.logits).all(), f"{label} combined transforms produce NaN/Inf" + + def test_gpt2_kv_transform_then_custom_ops_no_crash(self): + """Applying KVCacheTransform then CustomOpsTransform to GPT2 must not crash.""" + model = make_tiny_gpt2() + model, _ = KVCacheTransform.apply(model) + model, applied = CustomOpsTransform.apply(model) + assert not applied, "CustomOpsTransform must not apply to GPT2" + + +# --------------------------------------------------------------------------- +# Tests: PoolingTransform +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestPoolingTransformCorrectness: + """PoolingTransform must produce correct pooled embeddings.""" + + def test_mean_pooling_wraps_model(self): + from transformers import BertConfig, BertModel + + from QEfficient.transformers.embeddings.embedding_utils import PooledModel + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + model = BertModel(cfg).eval() + pooled, applied = PoolingTransform.apply(model, pooling="mean") + assert isinstance(pooled, PooledModel) + + def test_cls_pooling_wraps_model(self): + from transformers import BertConfig, BertModel + + from QEfficient.transformers.embeddings.embedding_utils import PooledModel + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + model = BertModel(cfg).eval() + pooled, applied = PoolingTransform.apply(model, pooling="cls") + assert isinstance(pooled, PooledModel) + + def test_invalid_pooling_raises_error(self): + from transformers import BertConfig, BertModel + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + model = BertModel(cfg).eval() + with pytest.raises((ValueError, KeyError, TypeError)): + PoolingTransform.apply(model, pooling="invalid_pooling_xyz") + + def test_mean_pooled_output_matches_manual_mean(self): + """PooledModel mean output must match manually computed mean pooling.""" + from transformers import BertConfig, BertModel + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + model = BertModel(cfg).eval() + inputs = { + "input_ids": torch.randint(0, 500, (1, 16)), + "attention_mask": torch.ones(1, 16, dtype=torch.long), + } + + with torch.no_grad(): + hf_out = model(**inputs) + mask = inputs["attention_mask"].unsqueeze(-1).float() + manual_mean = (hf_out.last_hidden_state * mask).sum(1) / mask.sum(1) + + pooled, _ = PoolingTransform.apply(model, pooling="mean") + with torch.no_grad(): + pooled_mean = pooled(**inputs) + + max_diff = (manual_mean - pooled_mean).abs().max().item() + assert max_diff < 1e-5, f"Mean pooling mismatch: max_diff={max_diff:.2e}" + + def test_max_pooling_wraps_model(self): + """PoolingTransform with pooling='max' must wrap the model in PooledModel.""" + from transformers import BertConfig, BertModel + + from QEfficient.transformers.embeddings.embedding_utils import PooledModel + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + model = BertModel(cfg).eval() + pooled, _ = PoolingTransform.apply(model, pooling="max") + # PoolingTransform always returns applied=False (it wraps, not replaces) + assert isinstance(pooled, PooledModel) + + def test_max_pooled_output_matches_manual_max(self): + """PooledModel max output must match manually computed max pooling.""" + from transformers import BertConfig, BertModel + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + model = BertModel(cfg).eval() + inputs = { + "input_ids": torch.randint(0, 500, (1, 16)), + "attention_mask": torch.ones(1, 16, dtype=torch.long), + } + + with torch.no_grad(): + hf_out = model(**inputs) + # Manual max pooling: max over sequence dimension + manual_max = hf_out.last_hidden_state.max(dim=1).values + + pooled, _ = PoolingTransform.apply(model, pooling="max") + with torch.no_grad(): + pooled_max = pooled(**inputs) + + max_diff = (manual_max - pooled_max).abs().max().item() + assert max_diff < 1e-5, f"Max pooling mismatch: max_diff={max_diff:.2e}" + + def test_avg_pooling_wraps_model(self): + """PoolingTransform with pooling='avg' must wrap the model in PooledModel.""" + from transformers import BertConfig, BertModel + + from QEfficient.transformers.embeddings.embedding_utils import PooledModel + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + model = BertModel(cfg).eval() + # 'avg' is supported in POOLING_MAP + pooled, _ = PoolingTransform.apply(model, pooling="avg") + assert isinstance(pooled, PooledModel) + + def test_custom_callable_pooling_is_accepted(self): + """PoolingTransform must accept a callable as the pooling argument.""" + from transformers import BertConfig, BertModel + + from QEfficient.transformers.embeddings.embedding_utils import PooledModel + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + model = BertModel(cfg).eval() + + def custom_pool(last_hidden_states, attention_mask): + # Simple: return first token (like CLS) + return last_hidden_states[:, 0, :] + + try: + pooled, _ = PoolingTransform.apply(model, pooling=custom_pool) + assert isinstance(pooled, PooledModel) + except (ValueError, TypeError, NotImplementedError): + # If custom callable is not supported, skip + pytest.skip("Custom callable pooling not supported in this version") + + def test_pooling_output_is_finite(self): + """Pooled output must be finite (no NaN/Inf).""" + from transformers import BertConfig, BertModel + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + model = BertModel(cfg).eval() + inputs = { + "input_ids": torch.randint(0, 500, (1, 16)), + "attention_mask": torch.ones(1, 16, dtype=torch.long), + } + + for pooling_type in ["mean", "cls", "max"]: + try: + pooled, _ = PoolingTransform.apply(model, pooling=pooling_type) + with torch.no_grad(): + output = pooled(**inputs) + assert torch.isfinite(output).all(), f"Pooled output for '{pooling_type}' must be finite" + except (ValueError, KeyError): + pass # Skip unsupported pooling types + + +# --------------------------------------------------------------------------- +# Tests: SamplerTransform +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestSamplerTransformBehavior: + """SamplerTransform must only apply when qaic_config has include_sampler=True.""" + + def test_no_transform_when_qaic_config_is_none(self): + model = make_tiny_gpt2() + kv_model, _ = KVCacheTransform.apply(model) + _, applied = SamplerTransform.apply(kv_model, qaic_config=None) + assert not applied + + def test_no_transform_when_include_sampler_false(self): + model = make_tiny_gpt2() + kv_model, _ = KVCacheTransform.apply(model) + _, applied = SamplerTransform.apply(kv_model, qaic_config={"include_sampler": False}) + assert not applied + + def test_unsupported_model_raises_not_implemented(self): + import torch.nn as nn + + class UnsupportedModel(nn.Module): + def forward(self, x): + return x + + with pytest.raises(NotImplementedError): + SamplerTransform.apply(UnsupportedModel(), qaic_config={"include_sampler": True}) + + def test_supported_model_classes_include_gpt2_and_llama(self): + from QEfficient.transformers.models.gpt2.modeling_gpt2 import QEffGPT2LMHeadModel + from QEfficient.transformers.models.llama.modeling_llama import QEffLlamaForCausalLM + + assert QEffGPT2LMHeadModel in SamplerTransform._module_mapping + assert QEffLlamaForCausalLM in SamplerTransform._module_mapping + + +# --------------------------------------------------------------------------- +# Tests: SpDTransform +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestSpDTransformBehavior: + """SpDTransform must only apply when speculative_model_type is in qaic_config.""" + + def test_no_transform_when_qaic_config_is_none(self): + model = make_tiny_llama() + kv_model, _ = KVCacheTransform.apply(model) + _, applied = SpDTransform.apply(kv_model, qaic_config=None) + assert not applied + + def test_no_transform_when_speculative_model_type_missing(self): + model = make_tiny_llama() + kv_model, _ = KVCacheTransform.apply(model) + _, applied = SpDTransform.apply(kv_model, qaic_config={}) + assert not applied + + def test_invalid_speculative_model_type_raises_value_error(self): + model = make_tiny_llama() + kv_model, _ = KVCacheTransform.apply(model) + with pytest.raises(ValueError): + SpDTransform.apply(kv_model, qaic_config={"speculative_model_type": "invalid_xyz"}) + + def test_module_mapping_contains_llama_and_qwen2(self): + from QEfficient.transformers.models.llama.modeling_llama import QEffLlamaForCausalLM + from QEfficient.transformers.models.qwen2.modeling_qwen2 import QEffQwen2ForCausalLM + + assert QEffLlamaForCausalLM in SpDTransform._module_mapping + assert QEffQwen2ForCausalLM in SpDTransform._module_mapping + + +# --------------------------------------------------------------------------- +# Tests: SamplerTransform actual apply +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestSamplerTransformActualApply: + """SamplerTransform with include_sampler=True must attach sampler_forward.""" + + def test_sampler_transform_applies_to_gpt2_with_include_sampler_true(self): + """SamplerTransform must apply to QEffGPT2LMHeadModel when include_sampler=True.""" + model = make_tiny_gpt2() + kv_model, _ = KVCacheTransform.apply(model) + _, applied = SamplerTransform.apply(kv_model, qaic_config={"include_sampler": True}) + assert applied, "SamplerTransform must apply when include_sampler=True" + + def test_sampler_transform_applies_to_llama_with_include_sampler_true(self): + """SamplerTransform must apply to QEffLlamaForCausalLM when include_sampler=True.""" + model = make_tiny_llama() + kv_model, _ = KVCacheTransform.apply(model) + _, applied = SamplerTransform.apply(kv_model, qaic_config={"include_sampler": True}) + assert applied, "SamplerTransform must apply to Llama when include_sampler=True" + + def test_sampler_transform_saves_old_forward(self): + """After SamplerTransform, model.old_forward must be set to the original forward.""" + model = make_tiny_gpt2() + kv_model, _ = KVCacheTransform.apply(model) + original_forward = kv_model.forward + SamplerTransform.apply(kv_model, qaic_config={"include_sampler": True}) + assert hasattr(kv_model, "old_forward"), "SamplerTransform must save old_forward" + assert kv_model.old_forward == original_forward, "old_forward must be the original forward method" + + def test_sampler_transform_replaces_forward_with_sampler_forward(self): + """After SamplerTransform, model.forward must be replaced.""" + model = make_tiny_gpt2() + kv_model, _ = KVCacheTransform.apply(model) + original_forward = kv_model.forward + SamplerTransform.apply(kv_model, qaic_config={"include_sampler": True}) + # The forward must have been replaced + assert kv_model.forward is not original_forward, "SamplerTransform must replace model.forward" + + def test_sampler_transform_returns_same_model_instance(self): + """SamplerTransform must modify model in-place.""" + model = make_tiny_gpt2() + kv_model, _ = KVCacheTransform.apply(model) + transformed, applied = SamplerTransform.apply(kv_model, qaic_config={"include_sampler": True}) + assert applied + assert transformed is kv_model, "SamplerTransform must modify model in-place" + + def test_sampler_transform_module_mapping_contains_gpt2_and_llama(self): + from QEfficient.transformers.models.gpt2.modeling_gpt2 import QEffGPT2LMHeadModel + from QEfficient.transformers.models.llama.modeling_llama import QEffLlamaForCausalLM + + assert QEffGPT2LMHeadModel in SamplerTransform._module_mapping + assert QEffLlamaForCausalLM in SamplerTransform._module_mapping + + def test_sampler_transform_module_mapping_contains_phi3_and_qwen2(self): + from QEfficient.transformers.models.phi3.modeling_phi3 import QEffPhi3ForCausalLM + from QEfficient.transformers.models.qwen2.modeling_qwen2 import QEffQwen2ForCausalLM + + assert QEffPhi3ForCausalLM in SamplerTransform._module_mapping + assert QEffQwen2ForCausalLM in SamplerTransform._module_mapping + + +# --------------------------------------------------------------------------- +# Tests: MoE transform (Mixtral) +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestMoETransformReplacement: + """KVCacheTransform must replace MixtralSparseMoeBlock with QEffMixtralSparseMoeBlock.""" + + def _make_tiny_mixtral(self): + from transformers import MixtralConfig, MixtralForCausalLM + + cfg = MixtralConfig( + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=VOCAB_SIZE, + max_position_embeddings=CTX_LEN, + num_experts_per_tok=2, + num_local_experts=4, + ) + return MixtralForCausalLM(cfg).eval(), cfg + + def test_mixtral_sparse_moe_block_replaced(self): + from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + + from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import QEffMixtralSparseMoeBlock + + model, cfg = self._make_tiny_mixtral() + assert any(isinstance(m, MixtralSparseMoeBlock) for m in model.modules()) + + transformed, applied = KVCacheTransform.apply(model) + assert applied + + for m in transformed.modules(): + if type(m) is MixtralSparseMoeBlock: + pytest.fail("Found unreplaced MixtralSparseMoeBlock after transform") + + assert any(isinstance(m, QEffMixtralSparseMoeBlock) for m in transformed.modules()) + + def test_mixtral_for_causal_lm_replaced(self): + from QEfficient.transformers.models.mixtral_moe.modeling_mixtral import QEffMixtralForCausalLM + + model, cfg = self._make_tiny_mixtral() + transformed, _ = KVCacheTransform.apply(model) + assert isinstance(transformed, QEffMixtralForCausalLM) + + def test_mixtral_greedy_token_preserved_after_kv_transform(self): + """Mixtral greedy token must be preserved after KVCacheTransform.""" + model, cfg = self._make_tiny_mixtral() + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + + with torch.no_grad(): + before_token = model(input_ids=input_ids).logits[:, -1, :].argmax(-1).item() + + transformed, _ = KVCacheTransform.apply(model) + qeff_inputs = _make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + after_token = transformed(**qeff_inputs).logits[:, -1, :].argmax(-1).item() + + assert before_token == after_token, ( + f"Mixtral KVCacheTransform changed greedy token: before={before_token}, after={after_token}" + ) + + def test_mixtral_kv_transform_produces_finite_outputs(self): + model, cfg = self._make_tiny_mixtral() + transformed, _ = KVCacheTransform.apply(model) + input_ids = torch.randint(0, VOCAB_SIZE, (1, SEQ_LEN)) + qeff_inputs = _make_qeff_inputs(input_ids, cfg) + with torch.no_grad(): + out = transformed(**qeff_inputs) + assert torch.isfinite(out.logits).all(), "Mixtral KVCacheTransform must produce finite logits" + + +# --------------------------------------------------------------------------- +# Tests: T5ModelTransform +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestT5ModelTransform: + """T5ModelTransform must replace T5Attention and T5LayerNorm with QEff variants.""" + + def _make_tiny_t5(self): + from transformers import T5Config, T5ForConditionalGeneration + + cfg = T5Config( + num_heads=2, + d_model=64, + d_ff=128, + d_kv=32, + num_layers=1, + num_decoder_layers=1, + vocab_size=500, + relative_attention_num_buckets=8, + relative_attention_max_distance=16, + ) + return T5ForConditionalGeneration(cfg).eval(), cfg + + def test_t5_transform_importable(self): + from QEfficient.transformers.models.pytorch_transforms import T5ModelTransform + + assert T5ModelTransform is not None + + def test_t5_transform_has_module_mapping(self): + from QEfficient.transformers.models.pytorch_transforms import T5ModelTransform + + assert hasattr(T5ModelTransform, "_module_mapping") + assert len(T5ModelTransform._module_mapping) > 0 + + def test_t5_transform_maps_t5_attention(self): + from transformers.models.t5.modeling_t5 import T5Attention + + from QEfficient.transformers.models.pytorch_transforms import T5ModelTransform + + assert T5Attention in T5ModelTransform._module_mapping + qeff_cls = T5ModelTransform._module_mapping[T5Attention] + assert qeff_cls.__name__ == "QEffT5Attention" + + def test_t5_transform_maps_t5_layer_norm(self): + from transformers.models.t5.modeling_t5 import T5LayerNorm + + from QEfficient.transformers.models.pytorch_transforms import T5ModelTransform + + assert T5LayerNorm in T5ModelTransform._module_mapping + qeff_cls = T5ModelTransform._module_mapping[T5LayerNorm] + assert qeff_cls.__name__ == "QEffT5LayerNorm" + + def test_t5_transform_replaces_attention(self): + from transformers.models.t5.modeling_t5 import T5Attention + + from QEfficient.transformers.models.pytorch_transforms import T5ModelTransform + + model, cfg = self._make_tiny_t5() + assert any(isinstance(m, T5Attention) for m in model.modules()) + + transformed, applied = T5ModelTransform.apply(model) + assert applied + + qeff_t5_attn_cls = T5ModelTransform._module_mapping[T5Attention] + for m in transformed.modules(): + if type(m) is T5Attention: + pytest.fail("Found unreplaced T5Attention after T5ModelTransform") + + assert any(isinstance(m, qeff_t5_attn_cls) for m in transformed.modules()) + + def test_t5_transform_replaces_layer_norm(self): + from transformers.models.t5.modeling_t5 import T5LayerNorm + + from QEfficient.transformers.models.pytorch_transforms import T5ModelTransform + + model, cfg = self._make_tiny_t5() + transformed, applied = T5ModelTransform.apply(model) + assert applied + qeff_t5_ln_cls = T5ModelTransform._module_mapping[T5LayerNorm] + assert any(isinstance(m, qeff_t5_ln_cls) for m in transformed.modules()) + + def test_t5_transform_has_apply_method(self): + from QEfficient.transformers.models.pytorch_transforms import T5ModelTransform + + assert hasattr(T5ModelTransform, "apply") + assert callable(T5ModelTransform.apply) + + +# --------------------------------------------------------------------------- +# Tests: TextClassificationTransform +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestTextClassificationTransformDirect: + """TextClassificationTransform must directly replace DisentangledSelfAttention.""" + + def _make_tiny_deberta(self): + from transformers import DebertaV2Config, DebertaV2ForSequenceClassification + + cfg = DebertaV2Config( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + num_labels=3, + type_vocab_size=0, + pos_att_type=["p2c", "c2p"], + ) + return DebertaV2ForSequenceClassification(cfg).eval(), cfg + + def test_text_classification_transform_importable(self): + from QEfficient.transformers.models.pytorch_transforms import TextClassificationTransform + + assert TextClassificationTransform is not None + + def test_text_classification_transform_has_module_mapping(self): + from QEfficient.transformers.models.pytorch_transforms import TextClassificationTransform + + assert hasattr(TextClassificationTransform, "_module_mapping") + assert len(TextClassificationTransform._module_mapping) > 0 + + def test_text_classification_transform_maps_disentangled_self_attention(self): + from transformers.models.deberta_v2.modeling_deberta_v2 import DisentangledSelfAttention + + from QEfficient.transformers.models.pytorch_transforms import TextClassificationTransform + + assert DisentangledSelfAttention in TextClassificationTransform._module_mapping + qeff_cls = TextClassificationTransform._module_mapping[DisentangledSelfAttention] + assert qeff_cls.__name__ == "QEffDisentangledSelfAttention" + + def test_text_classification_transform_replaces_attention(self): + from transformers.models.deberta_v2.modeling_deberta_v2 import DisentangledSelfAttention + + from QEfficient.transformers.models.pytorch_transforms import TextClassificationTransform + + try: + model, cfg = self._make_tiny_deberta() + except Exception as e: + pytest.skip(f"DeBERTa-v2 not available: {e}") + + assert any(isinstance(m, DisentangledSelfAttention) for m in model.modules()) + + transformed, applied = TextClassificationTransform.apply(model) + assert applied + + qeff_cls = TextClassificationTransform._module_mapping[DisentangledSelfAttention] + for m in transformed.modules(): + if type(m) is DisentangledSelfAttention: + pytest.fail("Found unreplaced DisentangledSelfAttention after transform") + + assert any(isinstance(m, qeff_cls) for m in transformed.modules()) + + def test_text_classification_transform_has_apply_method(self): + from QEfficient.transformers.models.pytorch_transforms import TextClassificationTransform + + assert hasattr(TextClassificationTransform, "apply") + assert callable(TextClassificationTransform.apply) + + +# --------------------------------------------------------------------------- +# Tests: BlockedKVAttentionTransform +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestBlockedKVAttentionTransform: + """BlockedKVAttentionTransform must patch forward with num_kv_blocks parameter.""" + + def test_blocked_kv_transform_importable(self): + from QEfficient.transformers.models.pytorch_transforms import BlockedKVAttentionTransform + + assert BlockedKVAttentionTransform is not None + + def test_blocked_kv_transform_has_module_mapping(self): + from QEfficient.transformers.models.pytorch_transforms import BlockedKVAttentionTransform + + assert hasattr(BlockedKVAttentionTransform, "_module_mapping") + assert len(BlockedKVAttentionTransform._module_mapping) > 0 + + def test_blocked_kv_transform_contains_llama_attention(self): + from QEfficient.transformers.models.llama.modeling_llama import QEffLlamaAttention + from QEfficient.transformers.models.pytorch_transforms import BlockedKVAttentionTransform + + assert QEffLlamaAttention in BlockedKVAttentionTransform._module_mapping + + def test_blocked_kv_transform_has_apply_method(self): + from QEfficient.transformers.models.pytorch_transforms import BlockedKVAttentionTransform + + assert hasattr(BlockedKVAttentionTransform, "apply") + assert callable(BlockedKVAttentionTransform.apply) + + def test_blocked_kv_transform_applies_to_llama(self): + """BlockedKVAttentionTransform must apply to a KV-transformed Llama model.""" + from QEfficient.transformers.models.pytorch_transforms import BlockedKVAttentionTransform + + model = make_tiny_llama() + kv_model, _ = KVCacheTransform.apply(model) + transformed, applied = BlockedKVAttentionTransform.apply(kv_model, num_kv_blocks=4) + assert applied, "BlockedKVAttentionTransform must apply to KV-transformed Llama" + + def test_blocked_kv_transform_patches_forward(self): + """After BlockedKVAttentionTransform, attention forward must be patched.""" + from QEfficient.transformers.models.llama.modeling_llama import QEffLlamaAttention + from QEfficient.transformers.models.pytorch_transforms import BlockedKVAttentionTransform + + model = make_tiny_llama() + kv_model, _ = KVCacheTransform.apply(model) + BlockedKVAttentionTransform.apply(kv_model, num_kv_blocks=4) + + # After transform, attention modules should have patched forward + for m in kv_model.modules(): + if isinstance(m, QEffLlamaAttention): + # The forward should be a partial function with num_kv_blocks + assert hasattr(m, "forward"), "Attention module must have forward after transform" + break + + def test_blocked_kv_transform_returns_model_and_bool(self): + from QEfficient.transformers.models.pytorch_transforms import BlockedKVAttentionTransform + + model = make_tiny_llama() + kv_model, _ = KVCacheTransform.apply(model) + result = BlockedKVAttentionTransform.apply(kv_model, num_kv_blocks=4) + assert len(result) == 2 + assert isinstance(result[1], bool) + + def test_blocked_kv_transform_does_not_apply_to_gpt2(self): + """BlockedKVAttentionTransform must not apply to GPT2 (not in mapping).""" + from QEfficient.transformers.models.pytorch_transforms import BlockedKVAttentionTransform + + model = make_tiny_gpt2() + kv_model, _ = KVCacheTransform.apply(model) + _, applied = BlockedKVAttentionTransform.apply(kv_model, num_kv_blocks=4) + assert not applied, "BlockedKVAttentionTransform must not apply to GPT2" + + +# --------------------------------------------------------------------------- +# Tests: PrefillOnly transforms (structure only - GPT_OSS is external) +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestPrefillOnlyTransformStructure: + """PrefillOnly transforms must have correct structure.""" + + def test_prefill_only_transform_importable(self): + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyTransform + + assert PrefillOnlyTransform is not None + + def test_prefill_only_chunked_transform_importable(self): + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyChunkedTransform + + assert PrefillOnlyChunkedTransform is not None + + def test_revert_prefill_only_transform_importable(self): + from QEfficient.transformers.models.pytorch_transforms import RevertPrefillOnlyTransform + + assert RevertPrefillOnlyTransform is not None + + def test_revert_prefill_keep_attention_transform_importable(self): + from QEfficient.transformers.models.pytorch_transforms import RevertPrefillKeepAttentionTransform + + assert RevertPrefillKeepAttentionTransform is not None + + def test_prefill_only_transform_has_module_mapping(self): + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyTransform + + assert hasattr(PrefillOnlyTransform, "_module_mapping") + assert len(PrefillOnlyTransform._module_mapping) > 0 + + def test_prefill_only_chunked_transform_has_module_mapping(self): + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyChunkedTransform + + assert hasattr(PrefillOnlyChunkedTransform, "_module_mapping") + assert len(PrefillOnlyChunkedTransform._module_mapping) > 0 + + def test_revert_prefill_only_transform_has_module_mapping(self): + from QEfficient.transformers.models.pytorch_transforms import RevertPrefillOnlyTransform + + assert hasattr(RevertPrefillOnlyTransform, "_module_mapping") + assert len(RevertPrefillOnlyTransform._module_mapping) > 0 + + def test_prefill_only_transform_maps_gpt_oss_model(self): + from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import ( + QEffGptOssModel, + QEffPrefillOnlyGptOssModel, + ) + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyTransform + + assert QEffGptOssModel in PrefillOnlyTransform._module_mapping + assert PrefillOnlyTransform._module_mapping[QEffGptOssModel] is QEffPrefillOnlyGptOssModel + + def test_prefill_only_transform_maps_gpt_oss_attention(self): + from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import ( + QEffGptOssAttention, + QEffPrefillOnlyGptOssAttention, + ) + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyTransform + + assert QEffGptOssAttention in PrefillOnlyTransform._module_mapping + assert PrefillOnlyTransform._module_mapping[QEffGptOssAttention] is QEffPrefillOnlyGptOssAttention + + def test_revert_prefill_only_is_inverse_of_prefill_only(self): + """RevertPrefillOnlyTransform must be the inverse of PrefillOnlyTransform for non-identity mappings.""" + from QEfficient.transformers.models.pytorch_transforms import ( + PrefillOnlyTransform, + RevertPrefillOnlyTransform, + ) + + # For each (src, dst) in PrefillOnlyTransform where src != dst, + # (dst, src) must be in RevertPrefillOnlyTransform + for src, dst in PrefillOnlyTransform._module_mapping.items(): + if src is dst: + continue # Skip identity mappings + assert dst in RevertPrefillOnlyTransform._module_mapping, ( + f"RevertPrefillOnlyTransform missing inverse mapping for {dst}" + ) + assert RevertPrefillOnlyTransform._module_mapping[dst] is src, ( + f"RevertPrefillOnlyTransform[{dst}] must be {src}" + ) + + def test_all_prefill_transforms_have_apply_method(self): + from QEfficient.transformers.models.pytorch_transforms import ( + PrefillOnlyChunkedTransform, + PrefillOnlyTransform, + RevertPrefillKeepAttentionTransform, + RevertPrefillOnlyTransform, + ) + + for cls in [ + PrefillOnlyTransform, + PrefillOnlyChunkedTransform, + RevertPrefillOnlyTransform, + RevertPrefillKeepAttentionTransform, + ]: + assert hasattr(cls, "apply"), f"{cls.__name__} missing apply method" + assert callable(cls.apply), f"{cls.__name__}.apply is not callable" + + +# --------------------------------------------------------------------------- +# Tests: VlmKVOffloadTransform (GAP D) +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestVlmKVOffloadTransform: + """VlmKVOffloadTransform must be importable and have correct module mapping.""" + + def test_vlm_kv_offload_transform_importable(self): + from QEfficient.transformers.models.pytorch_transforms import VlmKVOffloadTransform + + assert VlmKVOffloadTransform is not None + + def test_vlm_kv_offload_transform_has_module_mapping(self): + from QEfficient.transformers.models.pytorch_transforms import VlmKVOffloadTransform + + assert hasattr(VlmKVOffloadTransform, "_module_mapping") + assert len(VlmKVOffloadTransform._module_mapping) > 0 + + def test_vlm_kv_offload_transform_maps_mllama_cross_attention(self): + from transformers.models.mllama.modeling_mllama import MllamaTextCrossAttention + + from QEfficient.transformers.models.pytorch_transforms import VlmKVOffloadTransform + + assert MllamaTextCrossAttention in VlmKVOffloadTransform._module_mapping + + def test_vlm_kv_offload_transform_maps_to_two_qpc_variant(self): + from transformers.models.mllama.modeling_mllama import MllamaTextCrossAttention + + from QEfficient.transformers.models.mllama.modeling_mllama import QEffMllamaTextCrossAttentionTwoQPC + from QEfficient.transformers.models.pytorch_transforms import VlmKVOffloadTransform + + assert VlmKVOffloadTransform._module_mapping[MllamaTextCrossAttention] is QEffMllamaTextCrossAttentionTwoQPC + + def test_vlm_kv_offload_transform_has_apply_method(self): + from QEfficient.transformers.models.pytorch_transforms import VlmKVOffloadTransform + + assert hasattr(VlmKVOffloadTransform, "apply") + assert callable(VlmKVOffloadTransform.apply) + + +# --------------------------------------------------------------------------- +# Tests: VlmNoKVOffloadTransform (GAP D) +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestVlmNoKVOffloadTransform: + """VlmNoKVOffloadTransform must be importable and have correct module mapping.""" + + def test_vlm_no_kv_offload_transform_importable(self): + from QEfficient.transformers.models.pytorch_transforms import VlmNoKVOffloadTransform + + assert VlmNoKVOffloadTransform is not None + + def test_vlm_no_kv_offload_transform_has_module_mapping(self): + from QEfficient.transformers.models.pytorch_transforms import VlmNoKVOffloadTransform + + assert hasattr(VlmNoKVOffloadTransform, "_module_mapping") + assert len(VlmNoKVOffloadTransform._module_mapping) > 0 + + def test_vlm_no_kv_offload_transform_maps_mllama_cross_attention(self): + from transformers.models.mllama.modeling_mllama import MllamaTextCrossAttention + + from QEfficient.transformers.models.pytorch_transforms import VlmNoKVOffloadTransform + + assert MllamaTextCrossAttention in VlmNoKVOffloadTransform._module_mapping + + def test_vlm_no_kv_offload_transform_maps_to_single_qpc_variant(self): + from transformers.models.mllama.modeling_mllama import MllamaTextCrossAttention + + from QEfficient.transformers.models.mllama.modeling_mllama import QEffMllamaTextCrossAttentionSingleQPC + from QEfficient.transformers.models.pytorch_transforms import VlmNoKVOffloadTransform + + assert ( + VlmNoKVOffloadTransform._module_mapping[MllamaTextCrossAttention] is QEffMllamaTextCrossAttentionSingleQPC + ) + + def test_vlm_no_kv_offload_transform_has_apply_method(self): + from QEfficient.transformers.models.pytorch_transforms import VlmNoKVOffloadTransform + + assert hasattr(VlmNoKVOffloadTransform, "apply") + assert callable(VlmNoKVOffloadTransform.apply) + + def test_vlm_offload_and_no_offload_map_to_different_classes(self): + """VlmKVOffloadTransform and VlmNoKVOffloadTransform must map to different QEff classes.""" + from transformers.models.mllama.modeling_mllama import MllamaTextCrossAttention + + from QEfficient.transformers.models.pytorch_transforms import ( + VlmKVOffloadTransform, + VlmNoKVOffloadTransform, + ) + + offload_cls = VlmKVOffloadTransform._module_mapping[MllamaTextCrossAttention] + no_offload_cls = VlmNoKVOffloadTransform._module_mapping[MllamaTextCrossAttention] + assert offload_cls is not no_offload_cls, ( + "VlmKVOffloadTransform and VlmNoKVOffloadTransform must map to different classes" + ) + + +# --------------------------------------------------------------------------- +# Tests: KVCacheExternalModuleMapperTransform (GAP D) +# --------------------------------------------------------------------------- + + +@pytest.mark.transforms +class TestKVCacheExternalModuleMapperTransform: + """KVCacheExternalModuleMapperTransform must have correct string-based mappings.""" + + def test_external_mapper_transform_importable(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + assert KVCacheExternalModuleMapperTransform is not None + + def test_external_mapper_has_match_string_replace_method(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + assert hasattr(KVCacheExternalModuleMapperTransform, "_match_string_replace_method") + assert isinstance(KVCacheExternalModuleMapperTransform._match_string_replace_method, dict) + + def test_external_mapper_contains_internvl(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + assert "InternVLChatModel" in KVCacheExternalModuleMapperTransform._match_string_replace_method + + def test_external_mapper_contains_molmo(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + assert "MolmoForCausalLM" in KVCacheExternalModuleMapperTransform._match_string_replace_method + + def test_external_mapper_contains_grok1(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + assert "Grok1ModelForCausalLM" in KVCacheExternalModuleMapperTransform._match_string_replace_method + + def test_external_mapper_internvl_has_forward(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + internvl_mapping = KVCacheExternalModuleMapperTransform._match_string_replace_method["InternVLChatModel"] + assert "forward" in internvl_mapping + assert callable(internvl_mapping["forward"]) + + def test_external_mapper_molmo_has_forward(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + molmo_mapping = KVCacheExternalModuleMapperTransform._match_string_replace_method["MolmoForCausalLM"] + assert "forward" in molmo_mapping + assert callable(molmo_mapping["forward"]) + + def test_external_mapper_grok1_has_forward(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + grok1_mapping = KVCacheExternalModuleMapperTransform._match_string_replace_method["Grok1ModelForCausalLM"] + assert "forward" in grok1_mapping + assert callable(grok1_mapping["forward"]) + + def test_external_mapper_has_apply_method(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + assert hasattr(KVCacheExternalModuleMapperTransform, "apply") + assert callable(KVCacheExternalModuleMapperTransform.apply) + + def test_external_mapper_internvl_has_get_dummy_inputs(self): + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + internvl_mapping = KVCacheExternalModuleMapperTransform._match_string_replace_method["InternVLChatModel"] + assert "get_dummy_inputs" in internvl_mapping + assert callable(internvl_mapping["get_dummy_inputs"]) + + def test_external_mapper_rms_norm_has_forward(self): + """RMSLayerNorm must be mapped to CustomRMSNormAIC.forward.""" + from QEfficient.customop import CustomRMSNormAIC + from QEfficient.transformers.models.pytorch_transforms import KVCacheExternalModuleMapperTransform + + assert "RMSLayerNorm" in KVCacheExternalModuleMapperTransform._match_string_replace_method + rms_mapping = KVCacheExternalModuleMapperTransform._match_string_replace_method["RMSLayerNorm"] + assert rms_mapping["forward"] is CustomRMSNormAIC.forward diff --git a/tests/unit_test/utils/__init__.py b/tests/unit_test/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit_test/utils/test_auto_model_api.py b/tests/unit_test/utils/test_auto_model_api.py new file mode 100644 index 000000000..ae2a1d722 --- /dev/null +++ b/tests/unit_test/utils/test_auto_model_api.py @@ -0,0 +1,660 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Tests for QEFFAutoModel API surface in QEfficient. + +Tests verify: + - QEFFAutoModelForCausalLM wraps models correctly + - is_tlm property is False by default + - build_prefill_specialization returns dict with correct keys + - build_decode_specialization returns dict with correct keys + - check_and_get_num_speculative_tokens returns None for non-TLM + - prefill() method exists + - QEFFAutoModel (encoder) wraps BERT correctly + - QEFFAutoModelForCTC wraps Wav2Vec2 correctly + +All tests run on CPU only, using tiny in-memory models. +""" + +import pytest +import torch +from transformers import GPT2Config, GPT2LMHeadModel + +# --------------------------------------------------------------------------- +# Tiny model factories +# --------------------------------------------------------------------------- + + +def make_tiny_gpt2(): + cfg = GPT2Config(n_layer=1, n_head=2, n_embd=64, vocab_size=500, n_positions=32, n_ctx=32) + return GPT2LMHeadModel(cfg).eval() + + +def make_tiny_llama(): + from transformers import LlamaConfig, LlamaForCausalLM + + cfg = LlamaConfig( + num_hidden_layers=1, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=32, + ) + return LlamaForCausalLM(cfg).eval() + + +def make_tiny_bert(): + from transformers import BertConfig, BertModel + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + return BertModel(cfg).eval() + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForCausalLM basic wrapping +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestQEFFAutoModelForCausalLMBasic: + """QEFFAutoModelForCausalLM must wrap models and expose correct attributes.""" + + def test_wraps_gpt2_model(self): + """QEFFAutoModelForCausalLM must wrap a GPT2LMHeadModel.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + model = make_tiny_gpt2() + qeff = QEFFAutoModelForCausalLM(model) + assert qeff is not None + + def test_wraps_llama_model(self): + """QEFFAutoModelForCausalLM must wrap a LlamaForCausalLM.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + model = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model) + assert qeff is not None + + def test_is_tlm_false_by_default(self): + """is_tlm must be False when no SpD config is provided.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + model = make_tiny_gpt2() + qeff = QEFFAutoModelForCausalLM(model) + assert qeff.is_tlm is False + + def test_has_prefill_method(self): + """QEFFAutoModelForCausalLM must have a prefill() method.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + assert hasattr(QEFFAutoModelForCausalLM, "prefill") + assert callable(QEFFAutoModelForCausalLM.prefill) + + def test_has_export_method(self): + """QEFFAutoModelForCausalLM must have an export() method.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + assert hasattr(QEFFAutoModelForCausalLM, "export") + assert callable(QEFFAutoModelForCausalLM.export) + + def test_has_check_and_get_num_speculative_tokens(self): + """QEFFAutoModelForCausalLM must have check_and_get_num_speculative_tokens.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + assert hasattr(QEFFAutoModelForCausalLM, "check_and_get_num_speculative_tokens") + assert callable(QEFFAutoModelForCausalLM.check_and_get_num_speculative_tokens) + + def test_has_build_prefill_specialization(self): + """QEFFAutoModelForCausalLM must have build_prefill_specialization.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + assert hasattr(QEFFAutoModelForCausalLM, "build_prefill_specialization") + assert callable(QEFFAutoModelForCausalLM.build_prefill_specialization) + + def test_has_build_decode_specialization(self): + """QEFFAutoModelForCausalLM must have build_decode_specialization.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + assert hasattr(QEFFAutoModelForCausalLM, "build_decode_specialization") + assert callable(QEFFAutoModelForCausalLM.build_decode_specialization) + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForCausalLM specialization API +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestQEFFAutoModelForCausalLMSpecializations: + """build_prefill_specialization and build_decode_specialization must return correct dicts.""" + + def _make_qeff(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + return QEFFAutoModelForCausalLM(make_tiny_gpt2()) + + def test_build_prefill_specialization_returns_dict(self): + """build_prefill_specialization must return a dict.""" + qeff = self._make_qeff() + result = qeff.build_prefill_specialization(prefill_seq_len=8, ctx_len=32, batch_size=1, full_batch_size=None) + assert isinstance(result, dict), f"Expected dict, got {type(result)}" + + def test_build_prefill_specialization_has_seq_len_key(self): + """build_prefill_specialization dict must contain 'seq_len'.""" + qeff = self._make_qeff() + result = qeff.build_prefill_specialization(prefill_seq_len=8, ctx_len=32, batch_size=1, full_batch_size=None) + assert "seq_len" in result, f"'seq_len' not in prefill spec: {result}" + + def test_build_prefill_specialization_has_ctx_len_key(self): + """build_prefill_specialization dict must contain 'ctx_len'.""" + qeff = self._make_qeff() + result = qeff.build_prefill_specialization(prefill_seq_len=8, ctx_len=32, batch_size=1, full_batch_size=None) + assert "ctx_len" in result, f"'ctx_len' not in prefill spec: {result}" + + def test_build_prefill_specialization_seq_len_matches_input(self): + """build_prefill_specialization seq_len must match the input prefill_seq_len.""" + qeff = self._make_qeff() + result = qeff.build_prefill_specialization(prefill_seq_len=16, ctx_len=64, batch_size=1, full_batch_size=None) + assert result["seq_len"] == 16, f"Expected seq_len=16, got {result['seq_len']}" + + def test_build_prefill_specialization_ctx_len_matches_input(self): + """build_prefill_specialization ctx_len must match the input ctx_len.""" + qeff = self._make_qeff() + result = qeff.build_prefill_specialization(prefill_seq_len=8, ctx_len=64, batch_size=1, full_batch_size=None) + assert result["ctx_len"] == 64, f"Expected ctx_len=64, got {result['ctx_len']}" + + def test_build_decode_specialization_returns_dict(self): + """build_decode_specialization must return a dict.""" + qeff = self._make_qeff() + result = qeff.build_decode_specialization(ctx_len=32, batch_size=1, full_batch_size=None) + assert isinstance(result, dict), f"Expected dict, got {type(result)}" + + def test_build_decode_specialization_has_seq_len_key(self): + """build_decode_specialization dict must contain 'seq_len'.""" + qeff = self._make_qeff() + result = qeff.build_decode_specialization(ctx_len=32, batch_size=1, full_batch_size=None) + assert "seq_len" in result, f"'seq_len' not in decode spec: {result}" + + def test_build_decode_specialization_has_ctx_len_key(self): + """build_decode_specialization dict must contain 'ctx_len'.""" + qeff = self._make_qeff() + result = qeff.build_decode_specialization(ctx_len=32, batch_size=1, full_batch_size=None) + assert "ctx_len" in result, f"'ctx_len' not in decode spec: {result}" + + def test_build_decode_specialization_seq_len_is_1(self): + """build_decode_specialization seq_len must be 1 (decode step).""" + qeff = self._make_qeff() + result = qeff.build_decode_specialization(ctx_len=32, batch_size=1, full_batch_size=None) + assert result["seq_len"] == 1, f"Expected seq_len=1 for decode, got {result['seq_len']}" + + def test_build_decode_specialization_ctx_len_matches_input(self): + """build_decode_specialization ctx_len must match the input ctx_len.""" + qeff = self._make_qeff() + result = qeff.build_decode_specialization(ctx_len=64, batch_size=1, full_batch_size=None) + assert result["ctx_len"] == 64, f"Expected ctx_len=64, got {result['ctx_len']}" + + def test_check_and_get_num_speculative_tokens_returns_none_for_non_tlm(self): + """For non-TLM model, check_and_get_num_speculative_tokens must return None.""" + qeff = self._make_qeff() + result = qeff.check_and_get_num_speculative_tokens(num_speculative_tokens=None, prefill_seq_len=1) + assert result is None, f"Expected None for non-TLM, got {result}" + + def test_build_decode_specialization_with_num_speculative_tokens(self): + """build_decode_specialization with num_speculative_tokens must include it in result.""" + qeff = self._make_qeff() + result = qeff.build_decode_specialization( + ctx_len=32, batch_size=1, full_batch_size=None, num_speculative_tokens=3 + ) + assert isinstance(result, dict) + # The result should reflect the speculative tokens in some way + assert "ctx_len" in result + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForCausalLM prefill toggle +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestQEFFAutoModelForCausalLMPrefillToggle: + """prefill() method must exist and be callable.""" + + def test_prefill_method_is_callable(self): + """QEFFAutoModelForCausalLM.prefill must be callable.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + assert callable(QEFFAutoModelForCausalLM.prefill) + + def test_prefill_method_accepts_enable_parameter(self): + """prefill() must accept an 'enable' parameter.""" + import inspect + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + sig = inspect.signature(QEFFAutoModelForCausalLM.prefill) + assert "enable" in sig.parameters, f"prefill() must have 'enable' parameter, got: {list(sig.parameters.keys())}" + + def test_prefill_method_accepts_enable_chunking_parameter(self): + """prefill() must accept an 'enable_chunking' parameter.""" + import inspect + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + sig = inspect.signature(QEFFAutoModelForCausalLM.prefill) + assert "enable_chunking" in sig.parameters, ( + f"prefill() must have 'enable_chunking' parameter, got: {list(sig.parameters.keys())}" + ) + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModel (encoder) +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestQEFFAutoModelEncoder: + """QEFFAutoModel must wrap encoder-only models like BERT.""" + + def test_qeff_auto_model_is_importable(self): + """QEFFAutoModel must be importable.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModel + + assert QEFFAutoModel is not None + + def test_qeff_auto_model_wraps_bert(self): + """QEFFAutoModel must wrap a BertModel.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModel + + model = make_tiny_bert() + qeff = QEFFAutoModel(model) + assert qeff is not None + + def test_qeff_auto_model_has_export_method(self): + """QEFFAutoModel must have an export() method.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModel + + assert hasattr(QEFFAutoModel, "export") + assert callable(QEFFAutoModel.export) + + def test_qeff_auto_model_forward_produces_finite_hidden_states(self): + """QEFFAutoModel forward must produce finite hidden states.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModel + + model = make_tiny_bert() + qeff = QEFFAutoModel(model) + + input_ids = torch.randint(0, 500, (1, 16)) + attention_mask = torch.ones(1, 16, dtype=torch.long) + + with torch.no_grad(): + output = qeff.model(input_ids=input_ids, attention_mask=attention_mask) + + assert torch.isfinite(output.last_hidden_state).all(), "QEFFAutoModel forward must produce finite hidden states" + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForCTC +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestQEFFAutoModelForCTC: + """QEFFAutoModelForCTC must be importable and wrap CTC models.""" + + def test_qeff_auto_model_for_ctc_is_importable(self): + """QEFFAutoModelForCTC must be importable.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCTC + + assert QEFFAutoModelForCTC is not None + + def test_qeff_auto_model_for_ctc_has_export_method(self): + """QEFFAutoModelForCTC must have an export() method.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCTC + + assert hasattr(QEFFAutoModelForCTC, "export") + assert callable(QEFFAutoModelForCTC.export) + + def test_qeff_auto_model_for_ctc_class_attributes(self): + """QEFFAutoModelForCTC must have expected class attributes.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCTC + + # Must have _pytorch_transforms or similar + assert hasattr(QEFFAutoModelForCTC, "_pytorch_transforms") or hasattr( + QEFFAutoModelForCTC, "_onnx_transforms" + ), "QEFFAutoModelForCTC must have transform attributes" + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForSequenceClassification +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestQEFFAutoModelForSequenceClassification: + """QEFFAutoModelForSequenceClassification must be importable.""" + + def test_importable(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForSequenceClassification + + assert QEFFAutoModelForSequenceClassification is not None + + def test_has_export_method(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForSequenceClassification + + assert hasattr(QEFFAutoModelForSequenceClassification, "export") + + def test_wraps_bert_for_sequence_classification(self): + """QEFFAutoModelForSequenceClassification must wrap BertForSequenceClassification.""" + from transformers import BertConfig, BertForSequenceClassification + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForSequenceClassification + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + num_labels=3, + ) + model = BertForSequenceClassification(cfg).eval() + qeff = QEFFAutoModelForSequenceClassification(model) + assert qeff is not None + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForSpeechSeq2Seq +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestQEFFAutoModelForSpeechSeq2Seq: + """QEFFAutoModelForSpeechSeq2Seq must be importable.""" + + def test_importable(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForSpeechSeq2Seq + + assert QEFFAutoModelForSpeechSeq2Seq is not None + + def test_has_export_method(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForSpeechSeq2Seq + + assert hasattr(QEFFAutoModelForSpeechSeq2Seq, "export") + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForCausalLM model registry +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestQEFFAutoModelRegistry: + """QEFFAutoModelForCausalLM must have correct model registry.""" + + def test_has_pytorch_transforms_list(self): + """QEFFAutoModelForCausalLM must have _pytorch_transforms list.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + assert hasattr(QEFFAutoModelForCausalLM, "_pytorch_transforms") + assert isinstance(QEFFAutoModelForCausalLM._pytorch_transforms, list) + + def test_pytorch_transforms_contains_kv_cache_transform(self): + """_pytorch_transforms must contain KVCacheTransform.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + from QEfficient.transformers.models.pytorch_transforms import KVCacheTransform + + assert KVCacheTransform in QEFFAutoModelForCausalLM._pytorch_transforms + + def test_pytorch_transforms_contains_custom_ops_transform(self): + """_pytorch_transforms must contain CustomOpsTransform.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + from QEfficient.transformers.models.pytorch_transforms import CustomOpsTransform + + assert CustomOpsTransform in QEFFAutoModelForCausalLM._pytorch_transforms + + def test_has_onnx_transforms_list(self): + """QEFFAutoModelForCausalLM must have _onnx_transforms list.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + assert hasattr(QEFFAutoModelForCausalLM, "_onnx_transforms") + assert isinstance(QEFFAutoModelForCausalLM._onnx_transforms, list) + + def test_onnx_transforms_contains_fp16_clip(self): + """_onnx_transforms must contain FP16ClipTransform.""" + from QEfficient.base.onnx_transforms import FP16ClipTransform + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + assert FP16ClipTransform in QEFFAutoModelForCausalLM._onnx_transforms + + def test_onnx_transforms_contains_split_tensors(self): + """_onnx_transforms must contain SplitTensorsTransform.""" + from QEfficient.base.onnx_transforms import SplitTensorsTransform + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + assert SplitTensorsTransform in QEFFAutoModelForCausalLM._onnx_transforms + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForCausalLM CCL mode (GAP F) +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestQEFFAutoModelForCausalLMCCL: + """CCL specialization methods must include comp_ctx_lengths in the result.""" + + def _make_qeff(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + return QEFFAutoModelForCausalLM(make_tiny_gpt2()) + + def test_build_prefill_specialization_with_ccl_returns_dict(self): + """build_prefill_specialization with comp_ctx_lengths must return a dict.""" + qeff = self._make_qeff() + result = qeff.build_prefill_specialization( + prefill_seq_len=8, + ctx_len=32, + batch_size=1, + full_batch_size=None, + comp_ctx_lengths=[16, 32], + ) + assert isinstance(result, dict), f"build_prefill_specialization with CCL must return dict, got {type(result)}" + + def test_build_decode_specialization_with_ccl_returns_dict(self): + """build_decode_specialization with comp_ctx_lengths must return a dict.""" + qeff = self._make_qeff() + result = qeff.build_decode_specialization( + ctx_len=32, + batch_size=1, + full_batch_size=None, + comp_ctx_lengths=[16, 32], + ) + assert isinstance(result, dict), f"build_decode_specialization with CCL must return dict, got {type(result)}" + + def test_build_prefill_specialization_ccl_result_has_comp_ctx_lengths_key(self): + """build_prefill_specialization with CCL must include 'comp_ctx_lengths' in result.""" + qeff = self._make_qeff() + result = qeff.build_prefill_specialization( + prefill_seq_len=8, + ctx_len=32, + batch_size=1, + full_batch_size=None, + comp_ctx_lengths=[16, 32], + ) + assert "comp_ctx_lengths" in result, f"CCL prefill spec must have 'comp_ctx_lengths' key: {result}" + + def test_build_decode_specialization_ccl_result_has_comp_ctx_lengths_key(self): + """build_decode_specialization with CCL must include 'comp_ctx_lengths' in result.""" + qeff = self._make_qeff() + result = qeff.build_decode_specialization( + ctx_len=32, + batch_size=1, + full_batch_size=None, + comp_ctx_lengths=[16, 32], + ) + assert "comp_ctx_lengths" in result, f"CCL decode spec must have 'comp_ctx_lengths' key: {result}" + + def test_build_prefill_specialization_ccl_preserves_comp_ctx_lengths_values(self): + """build_prefill_specialization must preserve the comp_ctx_lengths values.""" + qeff = self._make_qeff() + comp_ctx_lengths = [16, 32] + result = qeff.build_prefill_specialization( + prefill_seq_len=8, + ctx_len=32, + batch_size=1, + full_batch_size=None, + comp_ctx_lengths=comp_ctx_lengths, + ) + assert result["comp_ctx_lengths"] == comp_ctx_lengths, ( + f"Expected comp_ctx_lengths={comp_ctx_lengths}, got {result['comp_ctx_lengths']}" + ) + + def test_build_decode_specialization_ccl_preserves_comp_ctx_lengths_values(self): + """build_decode_specialization must preserve the comp_ctx_lengths values.""" + qeff = self._make_qeff() + comp_ctx_lengths = [16, 32] + result = qeff.build_decode_specialization( + ctx_len=32, + batch_size=1, + full_batch_size=None, + comp_ctx_lengths=comp_ctx_lengths, + ) + assert result["comp_ctx_lengths"] == comp_ctx_lengths, ( + f"Expected comp_ctx_lengths={comp_ctx_lengths}, got {result['comp_ctx_lengths']}" + ) + + def test_build_prefill_specialization_ccl_still_has_ctx_len(self): + """build_prefill_specialization with CCL must still have 'ctx_len' key.""" + qeff = self._make_qeff() + result = qeff.build_prefill_specialization( + prefill_seq_len=8, + ctx_len=32, + batch_size=1, + full_batch_size=None, + comp_ctx_lengths=[16, 32], + ) + assert "ctx_len" in result, f"CCL prefill spec must still have 'ctx_len': {result}" + + def test_build_decode_specialization_ccl_still_has_ctx_len(self): + """build_decode_specialization with CCL must still have 'ctx_len' key.""" + qeff = self._make_qeff() + result = qeff.build_decode_specialization( + ctx_len=32, + batch_size=1, + full_batch_size=None, + comp_ctx_lengths=[16, 32], + ) + assert "ctx_len" in result, f"CCL decode spec must still have 'ctx_len': {result}" + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForCausalLM prefill state change (GAP F) +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestQEFFAutoModelForCausalLMPrefillStateChange: + """prefill() method and PrefillOnlyTransform must have correct structure.""" + + def _make_qeff(self): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + return QEFFAutoModelForCausalLM(make_tiny_gpt2()) + + def test_prefill_method_is_callable(self): + """prefill() must be callable.""" + qeff = self._make_qeff() + assert callable(qeff.prefill) + + def test_prefill_method_accepts_enable_parameter(self): + """prefill() must accept an 'enable' parameter.""" + import inspect + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + sig = inspect.signature(QEFFAutoModelForCausalLM.prefill) + assert "enable" in sig.parameters + + def test_prefill_method_accepts_enable_chunking_parameter(self): + """prefill() must accept an 'enable_chunking' parameter.""" + import inspect + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + sig = inspect.signature(QEFFAutoModelForCausalLM.prefill) + assert "enable_chunking" in sig.parameters + + def test_prefill_method_accepts_retain_full_kv_parameter(self): + """prefill() must accept a 'retain_full_kv' parameter.""" + import inspect + + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + sig = inspect.signature(QEFFAutoModelForCausalLM.prefill) + assert "retain_full_kv" in sig.parameters + + def test_prefill_only_transform_importable(self): + """PrefillOnlyTransform must be importable.""" + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyTransform + + assert PrefillOnlyTransform is not None + + def test_prefill_only_transform_has_module_mapping(self): + """PrefillOnlyTransform must have a _module_mapping.""" + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyTransform + + assert hasattr(PrefillOnlyTransform, "_module_mapping") + assert isinstance(PrefillOnlyTransform._module_mapping, dict) + assert len(PrefillOnlyTransform._module_mapping) > 0 + + def test_revert_prefill_only_transform_importable(self): + """RevertPrefillOnlyTransform must be importable.""" + from QEfficient.transformers.models.pytorch_transforms import RevertPrefillOnlyTransform + + assert RevertPrefillOnlyTransform is not None + + def test_revert_prefill_only_transform_has_module_mapping(self): + """RevertPrefillOnlyTransform must have a _module_mapping.""" + from QEfficient.transformers.models.pytorch_transforms import RevertPrefillOnlyTransform + + assert hasattr(RevertPrefillOnlyTransform, "_module_mapping") + assert isinstance(RevertPrefillOnlyTransform._module_mapping, dict) + assert len(RevertPrefillOnlyTransform._module_mapping) > 0 + + def test_prefill_only_transform_maps_to_prefill_variants(self): + """PrefillOnlyTransform _module_mapping values must be prefill-only variants.""" + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyTransform + + for src_cls, dst_cls in PrefillOnlyTransform._module_mapping.items(): + dst_name = dst_cls.__name__ + assert "Prefill" in dst_name or "prefill" in dst_name.lower(), ( + f"PrefillOnlyTransform maps {src_cls.__name__} -> {dst_name}, " + f"but destination should be a prefill variant" + ) + + def test_prefill_only_chunked_transform_importable(self): + """PrefillOnlyChunkedTransform must be importable.""" + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyChunkedTransform + + assert PrefillOnlyChunkedTransform is not None + + def test_prefill_only_chunked_transform_has_module_mapping(self): + """PrefillOnlyChunkedTransform must have a _module_mapping.""" + from QEfficient.transformers.models.pytorch_transforms import PrefillOnlyChunkedTransform + + assert hasattr(PrefillOnlyChunkedTransform, "_module_mapping") + assert isinstance(PrefillOnlyChunkedTransform._module_mapping, dict) diff --git a/tests/unit_test/utils/test_cloud.py b/tests/unit_test/utils/test_cloud.py new file mode 100644 index 000000000..264942970 --- /dev/null +++ b/tests/unit_test/utils/test_cloud.py @@ -0,0 +1,1234 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +CPU-only tests for QEfficient.cloud module. + +Tests verify: + - Module importability + - Argument parsing for CLI scripts (compile.py, execute.py, export.py, infer.py) + - Function signatures and parameter validation + - Error handling for missing required arguments + - finetune.py helper functions (setup_seeds, apply_peft, etc.) + +All tests run on CPU only. No actual compilation, execution, or model loading +is performed - only argument parsing and function structure validation. +""" + +import argparse +import inspect +from unittest.mock import MagicMock + +import pytest + +# --------------------------------------------------------------------------- +# Tests: Module importability +# --------------------------------------------------------------------------- + + +class TestCloudModuleImportability: + """All cloud modules must be importable on CPU.""" + + def test_cloud_init_importable(self): + import QEfficient.cloud + + assert QEfficient.cloud is not None + + def test_compile_module_importable(self): + import QEfficient.cloud.compile + + assert QEfficient.cloud.compile is not None + + def test_execute_module_importable(self): + import QEfficient.cloud.execute + + assert QEfficient.cloud.execute is not None + + def test_export_module_importable(self): + import QEfficient.cloud.export + + assert QEfficient.cloud.export is not None + + def test_infer_module_importable(self): + import QEfficient.cloud.infer + + assert QEfficient.cloud.infer is not None + + def test_finetune_module_importable(self): + import QEfficient.cloud.finetune + + assert QEfficient.cloud.finetune is not None + + def test_finetune_experimental_importable(self): + import QEfficient.cloud.finetune_experimental + + assert QEfficient.cloud.finetune_experimental is not None + + +# --------------------------------------------------------------------------- +# Tests: export.py - function signatures +# --------------------------------------------------------------------------- + + +class TestExportFunctionSignatures: + """export.py functions must have correct signatures.""" + + def test_get_onnx_path_exists(self): + from QEfficient.cloud.export import get_onnx_path_and_setup_customIO + + assert callable(get_onnx_path_and_setup_customIO) + + def test_get_onnx_path_has_model_name(self): + from QEfficient.cloud.export import get_onnx_path_and_setup_customIO + + sig = inspect.signature(get_onnx_path_and_setup_customIO) + assert "model_name" in sig.parameters + + def test_get_onnx_path_has_cache_dir(self): + from QEfficient.cloud.export import get_onnx_path_and_setup_customIO + + sig = inspect.signature(get_onnx_path_and_setup_customIO) + assert "cache_dir" in sig.parameters + + def test_get_onnx_path_has_hf_token(self): + from QEfficient.cloud.export import get_onnx_path_and_setup_customIO + + sig = inspect.signature(get_onnx_path_and_setup_customIO) + assert "hf_token" in sig.parameters + + def test_get_onnx_path_has_full_batch_size(self): + from QEfficient.cloud.export import get_onnx_path_and_setup_customIO + + sig = inspect.signature(get_onnx_path_and_setup_customIO) + assert "full_batch_size" in sig.parameters + + def test_get_onnx_path_has_local_model_dir(self): + from QEfficient.cloud.export import get_onnx_path_and_setup_customIO + + sig = inspect.signature(get_onnx_path_and_setup_customIO) + assert "local_model_dir" in sig.parameters + + def test_get_onnx_path_has_mxint8_kv_cache(self): + from QEfficient.cloud.export import get_onnx_path_and_setup_customIO + + sig = inspect.signature(get_onnx_path_and_setup_customIO) + assert "mxint8_kv_cache" in sig.parameters + + def test_export_main_exists(self): + from QEfficient.cloud.export import main + + assert callable(main) + + def test_export_main_has_model_name(self): + from QEfficient.cloud.export import main + + sig = inspect.signature(main) + assert "model_name" in sig.parameters + + def test_export_main_has_cache_dir(self): + from QEfficient.cloud.export import main + + sig = inspect.signature(main) + assert "cache_dir" in sig.parameters + + def test_export_main_has_hf_token(self): + from QEfficient.cloud.export import main + + sig = inspect.signature(main) + assert "hf_token" in sig.parameters + + def test_export_main_has_local_model_dir(self): + from QEfficient.cloud.export import main + + sig = inspect.signature(main) + assert "local_model_dir" in sig.parameters + + def test_export_main_has_full_batch_size(self): + from QEfficient.cloud.export import main + + sig = inspect.signature(main) + assert "full_batch_size" in sig.parameters + + def test_export_main_has_mxint8_kv_cache(self): + from QEfficient.cloud.export import main + + sig = inspect.signature(main) + assert "mxint8_kv_cache" in sig.parameters + + +# --------------------------------------------------------------------------- +# Tests: execute.py - function signatures +# --------------------------------------------------------------------------- + + +class TestExecuteFunctionSignatures: + """execute.py main function must have correct signature.""" + + def test_main_exists(self): + from QEfficient.cloud.execute import main + + assert callable(main) + + def test_main_has_model_name(self): + from QEfficient.cloud.execute import main + + sig = inspect.signature(main) + assert "model_name" in sig.parameters + + def test_main_has_qpc_path(self): + from QEfficient.cloud.execute import main + + sig = inspect.signature(main) + assert "qpc_path" in sig.parameters + + def test_main_has_device_group(self): + from QEfficient.cloud.execute import main + + sig = inspect.signature(main) + assert "device_group" in sig.parameters + + def test_main_has_prompt(self): + from QEfficient.cloud.execute import main + + sig = inspect.signature(main) + assert "prompt" in sig.parameters + + def test_main_has_prompts_txt_file_path(self): + from QEfficient.cloud.execute import main + + sig = inspect.signature(main) + assert "prompts_txt_file_path" in sig.parameters + + def test_main_has_generation_len(self): + from QEfficient.cloud.execute import main + + sig = inspect.signature(main) + assert "generation_len" in sig.parameters + + def test_main_has_cache_dir(self): + from QEfficient.cloud.execute import main + + sig = inspect.signature(main) + assert "cache_dir" in sig.parameters + + def test_main_has_hf_token(self): + from QEfficient.cloud.execute import main + + sig = inspect.signature(main) + assert "hf_token" in sig.parameters + + def test_main_has_local_model_dir(self): + from QEfficient.cloud.execute import main + + sig = inspect.signature(main) + assert "local_model_dir" in sig.parameters + + +# --------------------------------------------------------------------------- +# Tests: infer.py - function signatures +# --------------------------------------------------------------------------- + + +class TestInferFunctionSignatures: + """infer.py functions must have correct signatures.""" + + def test_main_exists(self): + from QEfficient.cloud.infer import main + + assert callable(main) + + def test_main_has_model_name(self): + from QEfficient.cloud.infer import main + + sig = inspect.signature(main) + assert "model_name" in sig.parameters + + def test_main_has_num_cores(self): + from QEfficient.cloud.infer import main + + sig = inspect.signature(main) + assert "num_cores" in sig.parameters + + def test_main_has_device_group(self): + from QEfficient.cloud.infer import main + + sig = inspect.signature(main) + assert "device_group" in sig.parameters + + def test_main_has_prompt(self): + from QEfficient.cloud.infer import main + + sig = inspect.signature(main) + assert "prompt" in sig.parameters + + def test_main_has_batch_size(self): + from QEfficient.cloud.infer import main + + sig = inspect.signature(main) + assert "batch_size" in sig.parameters + + def test_main_has_ctx_len(self): + from QEfficient.cloud.infer import main + + sig = inspect.signature(main) + assert "ctx_len" in sig.parameters + + def test_main_has_prompt_len(self): + from QEfficient.cloud.infer import main + + sig = inspect.signature(main) + assert "prompt_len" in sig.parameters + + def test_main_has_mxfp6(self): + from QEfficient.cloud.infer import main + + sig = inspect.signature(main) + assert "mxfp6" in sig.parameters + + def test_main_has_mxint8(self): + from QEfficient.cloud.infer import main + + sig = inspect.signature(main) + assert "mxint8" in sig.parameters + + def test_main_has_generation_len(self): + from QEfficient.cloud.infer import main + + sig = inspect.signature(main) + assert "generation_len" in sig.parameters + + def test_main_has_full_batch_size(self): + from QEfficient.cloud.infer import main + + sig = inspect.signature(main) + assert "full_batch_size" in sig.parameters + + def test_main_has_enable_qnn(self): + from QEfficient.cloud.infer import main + + sig = inspect.signature(main) + assert "enable_qnn" in sig.parameters + + def test_main_has_cache_dir(self): + from QEfficient.cloud.infer import main + + sig = inspect.signature(main) + assert "cache_dir" in sig.parameters + + def test_main_has_hf_token(self): + from QEfficient.cloud.infer import main + + sig = inspect.signature(main) + assert "hf_token" in sig.parameters + + def test_execute_vlm_model_exists(self): + from QEfficient.cloud.infer import execute_vlm_model + + assert callable(execute_vlm_model) + + def test_execute_vlm_model_has_qeff_model(self): + from QEfficient.cloud.infer import execute_vlm_model + + sig = inspect.signature(execute_vlm_model) + assert "qeff_model" in sig.parameters + + def test_execute_vlm_model_has_model_name(self): + from QEfficient.cloud.infer import execute_vlm_model + + sig = inspect.signature(execute_vlm_model) + assert "model_name" in sig.parameters + + def test_execute_vlm_model_has_image_url(self): + from QEfficient.cloud.infer import execute_vlm_model + + sig = inspect.signature(execute_vlm_model) + assert "image_url" in sig.parameters + + def test_execute_vlm_model_has_image_path(self): + from QEfficient.cloud.infer import execute_vlm_model + + sig = inspect.signature(execute_vlm_model) + assert "image_path" in sig.parameters + + def test_execute_vlm_model_has_prompt(self): + from QEfficient.cloud.infer import execute_vlm_model + + sig = inspect.signature(execute_vlm_model) + assert "prompt" in sig.parameters + + def test_execute_vlm_model_has_generation_len(self): + from QEfficient.cloud.infer import execute_vlm_model + + sig = inspect.signature(execute_vlm_model) + assert "generation_len" in sig.parameters + + +# --------------------------------------------------------------------------- +# Tests: infer.py - execute_vlm_model error handling +# --------------------------------------------------------------------------- + + +class TestExecuteVlmModelErrorHandling: + """execute_vlm_model must raise ValueError when no image is provided.""" + + def test_raises_without_image_url_or_path(self): + from QEfficient.cloud.infer import execute_vlm_model + + with pytest.raises(ValueError, match="Neither Image URL nor Image Path"): + execute_vlm_model( + qeff_model=MagicMock(), + model_name="test", + image_url=None, + image_path=None, + prompt=["test"], + ) + + def test_raises_with_empty_image_url_and_no_path(self): + from QEfficient.cloud.infer import execute_vlm_model + + with pytest.raises(ValueError): + execute_vlm_model( + qeff_model=MagicMock(), + model_name="test", + image_url="", + image_path=None, + prompt=["test"], + ) + + def test_raises_with_empty_image_path_and_no_url(self): + from QEfficient.cloud.infer import execute_vlm_model + + with pytest.raises(ValueError): + execute_vlm_model( + qeff_model=MagicMock(), + model_name="test", + image_url=None, + image_path="", + prompt=["test"], + ) + + +# --------------------------------------------------------------------------- +# Tests: finetune.py - function signatures +# --------------------------------------------------------------------------- + + +class TestFinetuneFunctionSignatures: + """finetune.py functions must have correct signatures.""" + + def test_setup_distributed_training_exists(self): + from QEfficient.cloud.finetune import setup_distributed_training + + assert callable(setup_distributed_training) + + def test_setup_distributed_training_has_train_config(self): + from QEfficient.cloud.finetune import setup_distributed_training + + sig = inspect.signature(setup_distributed_training) + assert "train_config" in sig.parameters + + def test_setup_seeds_exists(self): + from QEfficient.cloud.finetune import setup_seeds + + assert callable(setup_seeds) + + def test_setup_seeds_has_seed(self): + from QEfficient.cloud.finetune import setup_seeds + + sig = inspect.signature(setup_seeds) + assert "seed" in sig.parameters + + def test_load_model_and_tokenizer_exists(self): + from QEfficient.cloud.finetune import load_model_and_tokenizer + + assert callable(load_model_and_tokenizer) + + def test_load_model_and_tokenizer_has_train_config(self): + from QEfficient.cloud.finetune import load_model_and_tokenizer + + sig = inspect.signature(load_model_and_tokenizer) + assert "train_config" in sig.parameters + + def test_load_model_and_tokenizer_has_dataset_config(self): + from QEfficient.cloud.finetune import load_model_and_tokenizer + + sig = inspect.signature(load_model_and_tokenizer) + assert "dataset_config" in sig.parameters + + def test_apply_peft_exists(self): + from QEfficient.cloud.finetune import apply_peft + + assert callable(apply_peft) + + def test_apply_peft_has_model(self): + from QEfficient.cloud.finetune import apply_peft + + sig = inspect.signature(apply_peft) + assert "model" in sig.parameters + + def test_apply_peft_has_train_config(self): + from QEfficient.cloud.finetune import apply_peft + + sig = inspect.signature(apply_peft) + assert "train_config" in sig.parameters + + def test_setup_dataloaders_exists(self): + from QEfficient.cloud.finetune import setup_dataloaders + + assert callable(setup_dataloaders) + + def test_setup_dataloaders_has_train_config(self): + from QEfficient.cloud.finetune import setup_dataloaders + + sig = inspect.signature(setup_dataloaders) + assert "train_config" in sig.parameters + + def test_setup_dataloaders_has_dataset_config(self): + from QEfficient.cloud.finetune import setup_dataloaders + + sig = inspect.signature(setup_dataloaders) + assert "dataset_config" in sig.parameters + + def test_setup_dataloaders_has_tokenizer(self): + from QEfficient.cloud.finetune import setup_dataloaders + + sig = inspect.signature(setup_dataloaders) + assert "tokenizer" in sig.parameters + + def test_main_exists(self): + from QEfficient.cloud.finetune import main + + assert callable(main) + + +# --------------------------------------------------------------------------- +# Tests: finetune.py - setup_seeds behavior +# --------------------------------------------------------------------------- + + +class TestSetupSeeds: + """setup_seeds must set random seeds correctly.""" + + def test_setup_seeds_does_not_crash(self): + from QEfficient.cloud.finetune import setup_seeds + + setup_seeds(42) + + def test_setup_seeds_with_different_values(self): + from QEfficient.cloud.finetune import setup_seeds + + for seed in [0, 1, 42, 100, 9999]: + setup_seeds(seed) + + def test_setup_seeds_torch_reproducibility(self): + import torch + + from QEfficient.cloud.finetune import setup_seeds + + setup_seeds(42) + torch.manual_seed(42) + a = torch.rand(5).tolist() + torch.manual_seed(42) + b = torch.rand(5).tolist() + assert a == b, "torch.manual_seed must produce reproducible results" + + def test_setup_seeds_numpy_reproducibility(self): + import numpy as np + + from QEfficient.cloud.finetune import setup_seeds + + setup_seeds(42) + np.random.seed(42) + a = np.random.rand(5).tolist() + np.random.seed(42) + b = np.random.rand(5).tolist() + assert a == b, "np.random.seed must produce reproducible results" + + +# --------------------------------------------------------------------------- +# Tests: finetune.py - apply_peft behavior +# --------------------------------------------------------------------------- + + +class TestApplyPeft: + """apply_peft must return model unchanged when use_peft=False.""" + + def test_apply_peft_returns_model_when_peft_disabled(self): + from QEfficient.cloud.finetune import apply_peft + from QEfficient.finetune.configs.training import TrainConfig + + train_config = TrainConfig() + train_config.use_peft = False + + mock_model = MagicMock() + result = apply_peft(mock_model, train_config) + assert result is mock_model, "apply_peft must return original model when use_peft=False" + + def test_apply_peft_does_not_modify_model_when_disabled(self): + from QEfficient.cloud.finetune import apply_peft + from QEfficient.finetune.configs.training import TrainConfig + + train_config = TrainConfig() + train_config.use_peft = False + + mock_model = MagicMock() + original_id = id(mock_model) + result = apply_peft(mock_model, train_config) + assert id(result) == original_id + + +# --------------------------------------------------------------------------- +# Tests: Argument parsing - compile.py +# --------------------------------------------------------------------------- + + +class TestCompileArgumentParsing: + """compile.py argument parser must handle required and optional args.""" + + def _get_parser(self): + parser = argparse.ArgumentParser(description="Compilation script.") + parser.add_argument("--onnx_path", "--onnx-path", required=True) + parser.add_argument("--qpc-path", "--qpc_path", required=True) + parser.add_argument("--batch_size", "--batch-size", type=int, default=1) + parser.add_argument("--prompt_len", "--prompt-len", default=32, type=int) + parser.add_argument("--ctx_len", "--ctx-len", default=128, type=int) + parser.add_argument("--mxfp6", action="store_true") + parser.add_argument("--mxint8", action="store_true") + parser.add_argument("--num_cores", "--num-cores", required=True, type=int) + parser.add_argument( + "--device_group", + "--device-group", + required=True, + type=lambda device_ids: [int(x) for x in device_ids.strip("[]").split(",")], + ) + parser.add_argument("--aic_enable_depth_first", "--aic-enable-depth-first", action="store_true") + parser.add_argument("--mos", type=int, default=-1) + parser.add_argument("--full_batch_size", "--full-batch-size", type=int, default=None) + return parser + + def test_parser_requires_onnx_path(self): + parser = self._get_parser() + with pytest.raises(SystemExit): + parser.parse_args([]) + + def test_parser_requires_num_cores(self): + parser = self._get_parser() + with pytest.raises(SystemExit): + parser.parse_args(["--onnx_path", "/path/to/model.onnx", "--qpc-path", "/path/to/qpc"]) + + def test_parser_requires_device_group(self): + parser = self._get_parser() + with pytest.raises(SystemExit): + parser.parse_args(["--onnx_path", "/path/to/model.onnx", "--qpc-path", "/path/to/qpc", "--num-cores", "16"]) + + def test_parser_accepts_all_required_args(self): + parser = self._get_parser() + args = parser.parse_args( + [ + "--onnx_path", + "/path/to/model.onnx", + "--qpc-path", + "/path/to/qpc", + "--num-cores", + "16", + "--device-group", + "[0]", + ] + ) + assert args.onnx_path == "/path/to/model.onnx" + assert args.num_cores == 16 + + def test_parser_default_batch_size_is_1(self): + parser = self._get_parser() + args = parser.parse_args( + [ + "--onnx_path", + "/path/to/model.onnx", + "--qpc-path", + "/path/to/qpc", + "--num-cores", + "16", + "--device-group", + "[0]", + ] + ) + assert args.batch_size == 1 + + def test_parser_default_prompt_len_is_32(self): + parser = self._get_parser() + args = parser.parse_args( + [ + "--onnx_path", + "/path/to/model.onnx", + "--qpc-path", + "/path/to/qpc", + "--num-cores", + "16", + "--device-group", + "[0]", + ] + ) + assert args.prompt_len == 32 + + def test_parser_default_ctx_len_is_128(self): + parser = self._get_parser() + args = parser.parse_args( + [ + "--onnx_path", + "/path/to/model.onnx", + "--qpc-path", + "/path/to/qpc", + "--num-cores", + "16", + "--device-group", + "[0]", + ] + ) + assert args.ctx_len == 128 + + def test_parser_accepts_batch_size(self): + parser = self._get_parser() + args = parser.parse_args( + [ + "--onnx_path", + "/path/to/model.onnx", + "--qpc-path", + "/path/to/qpc", + "--num-cores", + "16", + "--device-group", + "[0]", + "--batch-size", + "4", + ] + ) + assert args.batch_size == 4 + + def test_parser_accepts_multi_device_group(self): + parser = self._get_parser() + args = parser.parse_args( + [ + "--onnx_path", + "/path/to/model.onnx", + "--qpc-path", + "/path/to/qpc", + "--num-cores", + "16", + "--device-group", + "[0,1,2,3]", + ] + ) + assert args.device_group == [0, 1, 2, 3] + + def test_parser_accepts_mxfp6_flag(self): + parser = self._get_parser() + args = parser.parse_args( + [ + "--onnx_path", + "/path/to/model.onnx", + "--qpc-path", + "/path/to/qpc", + "--num-cores", + "16", + "--device-group", + "[0]", + "--mxfp6", + ] + ) + assert args.mxfp6 is True + + def test_parser_accepts_mxint8_flag(self): + parser = self._get_parser() + args = parser.parse_args( + [ + "--onnx_path", + "/path/to/model.onnx", + "--qpc-path", + "/path/to/qpc", + "--num-cores", + "16", + "--device-group", + "[0]", + "--mxint8", + ] + ) + assert args.mxint8 is True + + def test_parser_accepts_aic_enable_depth_first(self): + parser = self._get_parser() + args = parser.parse_args( + [ + "--onnx_path", + "/path/to/model.onnx", + "--qpc-path", + "/path/to/qpc", + "--num-cores", + "16", + "--device-group", + "[0]", + "--aic-enable-depth-first", + ] + ) + assert args.aic_enable_depth_first is True + + def test_parser_accepts_full_batch_size(self): + parser = self._get_parser() + args = parser.parse_args( + [ + "--onnx_path", + "/path/to/model.onnx", + "--qpc-path", + "/path/to/qpc", + "--num-cores", + "16", + "--device-group", + "[0]", + "--full-batch-size", + "8", + ] + ) + assert args.full_batch_size == 8 + + def test_parser_default_full_batch_size_is_none(self): + parser = self._get_parser() + args = parser.parse_args( + [ + "--onnx_path", + "/path/to/model.onnx", + "--qpc-path", + "/path/to/qpc", + "--num-cores", + "16", + "--device-group", + "[0]", + ] + ) + assert args.full_batch_size is None + + +# --------------------------------------------------------------------------- +# Tests: Argument parsing - execute.py +# --------------------------------------------------------------------------- + + +class TestExecuteArgumentParsing: + """execute.py argument parser must handle required and optional args.""" + + def _get_parser(self): + parser = argparse.ArgumentParser(description="Execution script.") + parser.add_argument("--model_name", "--model-name", required=False, type=str) + parser.add_argument("--qpc_path", "--qpc-path", required=True) + parser.add_argument( + "--device_group", + "--device-group", + type=lambda device_ids: [int(x) for x in device_ids.strip("[]").split(",")], + ) + parser.add_argument("--prompt", type=lambda prompt: prompt.split("|")) + parser.add_argument("--prompts_txt_file_path", "--prompts-txt-file-path", type=str) + parser.add_argument("--generation_len", "--generation-len", type=int) + parser.add_argument("--local-model-dir", "--local_model_dir", required=False) + parser.add_argument("--cache-dir", "--cache_dir", default=None, required=False) + parser.add_argument("--full_batch_size", "--full-batch-size", type=int, default=None) + parser.add_argument("--hf-token", "--hf_token", default=None, type=str, required=False) + return parser + + def test_parser_requires_qpc_path(self): + parser = self._get_parser() + with pytest.raises(SystemExit): + parser.parse_args([]) + + def test_parser_accepts_qpc_path(self): + parser = self._get_parser() + args = parser.parse_args(["--qpc_path", "/path/to/qpc"]) + assert args.qpc_path == "/path/to/qpc" + + def test_parser_accepts_model_name(self): + parser = self._get_parser() + args = parser.parse_args(["--qpc_path", "/path/to/qpc", "--model_name", "gpt2"]) + assert args.model_name == "gpt2" + + def test_parser_accepts_prompt_with_pipe(self): + parser = self._get_parser() + args = parser.parse_args(["--qpc_path", "/path/to/qpc", "--prompt", "Hello|World|Test"]) + assert args.prompt == ["Hello", "World", "Test"] + + def test_parser_accepts_single_prompt(self): + parser = self._get_parser() + args = parser.parse_args(["--qpc_path", "/path/to/qpc", "--prompt", "Hello world"]) + assert args.prompt == ["Hello world"] + + def test_parser_accepts_generation_len(self): + parser = self._get_parser() + args = parser.parse_args(["--qpc_path", "/path/to/qpc", "--generation-len", "100"]) + assert args.generation_len == 100 + + def test_parser_accepts_device_group(self): + parser = self._get_parser() + args = parser.parse_args(["--qpc_path", "/path/to/qpc", "--device-group", "[0,1]"]) + assert args.device_group == [0, 1] + + def test_parser_default_generation_len_is_none(self): + parser = self._get_parser() + args = parser.parse_args(["--qpc_path", "/path/to/qpc"]) + assert args.generation_len is None + + def test_parser_accepts_hf_token(self): + parser = self._get_parser() + args = parser.parse_args(["--qpc_path", "/path/to/qpc", "--hf-token", "hf_abc123"]) + assert args.hf_token == "hf_abc123" + + +# --------------------------------------------------------------------------- +# Tests: Argument parsing - export.py +# --------------------------------------------------------------------------- + + +class TestExportArgumentParsing: + """export.py argument parser must handle required and optional args.""" + + def _get_parser(self): + parser = argparse.ArgumentParser(description="Export script.") + parser.add_argument("--model_name", "--model-name", required=True) + parser.add_argument("--local-model-dir", "--local_model_dir", required=False) + parser.add_argument("--cache_dir", "--cache-dir", required=False) + parser.add_argument("--hf-token", "--hf_token", default=None, type=str, required=False) + parser.add_argument("--full_batch_size", "--full-batch-size", type=int, default=None) + parser.add_argument("--mxint8_kv_cache", "--mxint8-kv-cache", required=False) + return parser + + def test_parser_requires_model_name(self): + parser = self._get_parser() + with pytest.raises(SystemExit): + parser.parse_args([]) + + def test_parser_accepts_model_name(self): + parser = self._get_parser() + args = parser.parse_args(["--model_name", "gpt2"]) + assert args.model_name == "gpt2" + + def test_parser_accepts_cache_dir(self): + parser = self._get_parser() + args = parser.parse_args(["--model_name", "gpt2", "--cache-dir", "/path/to/cache"]) + assert args.cache_dir == "/path/to/cache" + + def test_parser_accepts_hf_token(self): + parser = self._get_parser() + args = parser.parse_args(["--model_name", "gpt2", "--hf-token", "hf_token123"]) + assert args.hf_token == "hf_token123" + + def test_parser_accepts_full_batch_size(self): + parser = self._get_parser() + args = parser.parse_args(["--model_name", "gpt2", "--full-batch-size", "4"]) + assert args.full_batch_size == 4 + + def test_parser_default_full_batch_size_is_none(self): + parser = self._get_parser() + args = parser.parse_args(["--model_name", "gpt2"]) + assert args.full_batch_size is None + + +# --------------------------------------------------------------------------- +# Tests: Argument parsing - infer.py +# --------------------------------------------------------------------------- + + +class TestInferArgumentParsing: + """infer.py argument parser must handle required and optional args.""" + + def _get_parser(self): + parser = argparse.ArgumentParser(description="Inference script.") + parser.add_argument("--model-name", "--model_name", required=True, type=str) + parser.add_argument("--batch-size", "--batch_size", type=int, default=1) + parser.add_argument("--prompt-len", "--prompt_len", default=32, type=int) + parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int) + parser.add_argument("--num_cores", "--num-cores", type=int, required=True) + parser.add_argument( + "--device_group", + "--device-group", + type=lambda device_ids: [int(x) for x in device_ids.strip("[]").split(",")], + ) + parser.add_argument("--prompt", type=lambda prompt: prompt.split("|")) + parser.add_argument("--generation_len", "--generation-len", type=int) + parser.add_argument("--mxfp6", "--mxfp6_matmul", "--mxfp6-matmul", action="store_true") + parser.add_argument("--mxint8", "--mxint8_kv_cache", "--mxint8-kv-cache", action="store_true") + parser.add_argument("--full_batch_size", "--full-batch-size", type=int, default=None) + parser.add_argument("--aic_enable_depth_first", "--aic-enable-depth-first", action="store_true") + parser.add_argument("--mos", type=int, default=1) + parser.add_argument("--cache-dir", "--cache_dir", default=None, required=False) + parser.add_argument("--hf-token", "--hf_token", default=None, type=str, required=False) + parser.add_argument("--trust_remote_code", action="store_true", default=False) + return parser + + def test_parser_requires_model_name(self): + parser = self._get_parser() + with pytest.raises(SystemExit): + parser.parse_args([]) + + def test_parser_requires_num_cores(self): + parser = self._get_parser() + with pytest.raises(SystemExit): + parser.parse_args(["--model-name", "gpt2"]) + + def test_parser_accepts_all_required_args(self): + parser = self._get_parser() + args = parser.parse_args(["--model-name", "gpt2", "--num-cores", "16"]) + assert args.model_name == "gpt2" + assert args.num_cores == 16 + + def test_parser_default_batch_size_is_1(self): + parser = self._get_parser() + args = parser.parse_args(["--model-name", "gpt2", "--num-cores", "16"]) + assert args.batch_size == 1 + + def test_parser_default_prompt_len_is_32(self): + parser = self._get_parser() + args = parser.parse_args(["--model-name", "gpt2", "--num-cores", "16"]) + assert args.prompt_len == 32 + + def test_parser_default_ctx_len_is_128(self): + parser = self._get_parser() + args = parser.parse_args(["--model-name", "gpt2", "--num-cores", "16"]) + assert args.ctx_len == 128 + + def test_parser_accepts_mxfp6_flag(self): + parser = self._get_parser() + args = parser.parse_args(["--model-name", "gpt2", "--num-cores", "16", "--mxfp6"]) + assert args.mxfp6 is True + + def test_parser_accepts_mxint8_flag(self): + parser = self._get_parser() + args = parser.parse_args(["--model-name", "gpt2", "--num-cores", "16", "--mxint8"]) + assert args.mxint8 is True + + def test_parser_accepts_aic_enable_depth_first(self): + parser = self._get_parser() + args = parser.parse_args(["--model-name", "gpt2", "--num-cores", "16", "--aic-enable-depth-first"]) + assert args.aic_enable_depth_first is True + + def test_parser_accepts_full_batch_size(self): + parser = self._get_parser() + args = parser.parse_args(["--model-name", "gpt2", "--num-cores", "16", "--full-batch-size", "8"]) + assert args.full_batch_size == 8 + + def test_parser_accepts_trust_remote_code(self): + parser = self._get_parser() + args = parser.parse_args(["--model-name", "gpt2", "--num-cores", "16", "--trust_remote_code"]) + assert args.trust_remote_code is True + + def test_parser_default_trust_remote_code_is_false(self): + parser = self._get_parser() + args = parser.parse_args(["--model-name", "gpt2", "--num-cores", "16"]) + assert args.trust_remote_code is False + + def test_parser_accepts_prompt_with_pipe(self): + parser = self._get_parser() + args = parser.parse_args(["--model-name", "gpt2", "--num-cores", "16", "--prompt", "Hello|World"]) + assert args.prompt == ["Hello", "World"] + + def test_parser_accepts_device_group(self): + parser = self._get_parser() + args = parser.parse_args(["--model-name", "gpt2", "--num-cores", "16", "--device-group", "[0,1]"]) + assert args.device_group == [0, 1] + + +# --------------------------------------------------------------------------- +# Tests: Device group parsing utility +# --------------------------------------------------------------------------- + + +class TestDeviceGroupParsing: + """Device group lambda parser must correctly parse various formats.""" + + def _parse_device_group(self, s): + return [int(x) for x in s.strip("[]").split(",")] + + def test_single_device(self): + result = self._parse_device_group("[0]") + assert result == [0] + + def test_two_devices(self): + result = self._parse_device_group("[0,1]") + assert result == [0, 1] + + def test_four_devices(self): + result = self._parse_device_group("[0,1,2,3]") + assert result == [0, 1, 2, 3] + + def test_device_with_spaces(self): + result = self._parse_device_group("[0, 1, 2]") + assert result == [0, 1, 2] + + def test_single_digit_device(self): + result = self._parse_device_group("[7]") + assert result == [7] + + +# --------------------------------------------------------------------------- +# Tests: Prompt parsing utility +# --------------------------------------------------------------------------- + + +class TestPromptParsing: + """Prompt pipe-split lambda must correctly parse prompts.""" + + def _parse_prompt(self, s): + return s.split("|") + + def test_single_prompt(self): + result = self._parse_prompt("Hello world") + assert result == ["Hello world"] + + def test_two_prompts(self): + result = self._parse_prompt("Hello|World") + assert result == ["Hello", "World"] + + def test_three_prompts(self): + result = self._parse_prompt("A|B|C") + assert result == ["A", "B", "C"] + + def test_prompt_with_spaces(self): + result = self._parse_prompt("Hello world|How are you") + assert result == ["Hello world", "How are you"] + + def test_empty_prompt(self): + result = self._parse_prompt("") + assert result == [""] + + +# --------------------------------------------------------------------------- +# Tests: TrainConfig importability and defaults +# --------------------------------------------------------------------------- + + +class TestTrainConfig: + """TrainConfig must be importable and have correct defaults.""" + + def test_train_config_importable(self): + from QEfficient.finetune.configs.training import TrainConfig + + assert TrainConfig is not None + + def test_train_config_instantiable(self): + from QEfficient.finetune.configs.training import TrainConfig + + cfg = TrainConfig() + assert cfg is not None + + def test_train_config_has_model_name(self): + from QEfficient.finetune.configs.training import TrainConfig + + cfg = TrainConfig() + assert hasattr(cfg, "model_name") + + def test_train_config_has_use_peft(self): + from QEfficient.finetune.configs.training import TrainConfig + + cfg = TrainConfig() + assert hasattr(cfg, "use_peft") + + def test_train_config_has_seed(self): + from QEfficient.finetune.configs.training import TrainConfig + + cfg = TrainConfig() + assert hasattr(cfg, "seed") + + def test_train_config_has_device(self): + from QEfficient.finetune.configs.training import TrainConfig + + cfg = TrainConfig() + assert hasattr(cfg, "device") + + def test_train_config_has_enable_ddp(self): + from QEfficient.finetune.configs.training import TrainConfig + + cfg = TrainConfig() + assert hasattr(cfg, "enable_ddp") + + def test_train_config_has_lr(self): + from QEfficient.finetune.configs.training import TrainConfig + + cfg = TrainConfig() + assert hasattr(cfg, "lr") + + def test_train_config_has_gradient_checkpointing(self): + from QEfficient.finetune.configs.training import TrainConfig + + cfg = TrainConfig() + assert hasattr(cfg, "gradient_checkpointing") + + def test_train_config_use_peft_default_is_true(self): + from QEfficient.finetune.configs.training import TrainConfig + + cfg = TrainConfig() + assert cfg.use_peft is True + + def test_train_config_enable_ddp_default_is_false(self): + from QEfficient.finetune.configs.training import TrainConfig + + cfg = TrainConfig() + assert cfg.enable_ddp is False + + +# --------------------------------------------------------------------------- +# Tests: setup_distributed_training with DDP disabled +# --------------------------------------------------------------------------- + + +class TestSetupDistributedTraining: + """setup_distributed_training must handle non-DDP case without error.""" + + def test_non_ddp_cpu_does_not_crash(self): + from QEfficient.cloud.finetune import setup_distributed_training + from QEfficient.finetune.configs.training import TrainConfig + + train_config = TrainConfig() + train_config.enable_ddp = False + train_config.device = "cpu" + # Should not raise + setup_distributed_training(train_config) + + def test_non_ddp_returns_none(self): + from QEfficient.cloud.finetune import setup_distributed_training + from QEfficient.finetune.configs.training import TrainConfig + + train_config = TrainConfig() + train_config.enable_ddp = False + train_config.device = "cpu" + result = setup_distributed_training(train_config) + assert result is None + + +# --------------------------------------------------------------------------- +# Tests: check_and_assign_cache_dir utility +# --------------------------------------------------------------------------- + + +class TestCheckAndAssignCacheDir: + """check_and_assign_cache_dir must return correct cache directory.""" + + def test_function_importable(self): + from QEfficient.utils import check_and_assign_cache_dir + + assert callable(check_and_assign_cache_dir) + + def test_returns_cache_dir_when_provided(self): + from QEfficient.utils import check_and_assign_cache_dir + + result = check_and_assign_cache_dir(local_model_dir=None, cache_dir="/my/cache") + assert result == "/my/cache" + + def test_returns_default_when_local_model_dir_provided(self): + from QEfficient.utils import check_and_assign_cache_dir + + result = check_and_assign_cache_dir(local_model_dir="/local/model", cache_dir=None) + # When local_model_dir is provided, cache_dir should be None or default + assert result is None or isinstance(result, str) + + def test_returns_string_or_none(self): + from QEfficient.utils import check_and_assign_cache_dir + + result = check_and_assign_cache_dir(local_model_dir=None, cache_dir=None) + assert result is None or isinstance(result, str) diff --git a/tests/unit_test/utils/test_diffusers.py b/tests/unit_test/utils/test_diffusers.py new file mode 100644 index 000000000..f048df806 --- /dev/null +++ b/tests/unit_test/utils/test_diffusers.py @@ -0,0 +1,1124 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +CPU-only tests for QEfficient/diffusers module. + +Tests verify: + - Module importability (all diffusers sub-modules) + - Attention blocking config parsing (get_attention_blocking_config) + - Attention blocking functions: apply_head_blocking, apply_kv_blocking, + apply_q_blocking, apply_qkv_blocking, compute_blocked_attention + - QEff normalization layers: QEffAdaLayerNormZero, QEffAdaLayerNormZeroSingle, + QEffAdaLayerNormContinuous + - Diffusers transforms structure: CustomOpsTransform, AttentionTransform, + NormalizationTransform + - Pipeline utilities: calculate_compressed_latent_dimension, + calculate_latent_dimensions_with_frames, ModulePerf, QEffPipelineOutput + - Pipeline module class structure: QEffTextEncoder, QEffVAE, + QEffFluxTransformerModel, QEffWanUnifiedTransformer + - Flux transformer blocks: QEffFluxTransformerBlock, + QEffFluxSingleTransformerBlock, QEffFluxTransformer2DModel (tiny in-memory) + +All tests run on CPU only. No QAIC hardware required. No network downloads. +""" + +import os + +import pytest +import torch +import torch.nn.functional as F + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _standard_attention(q, k, v, attention_mask=None): + """Reference standard scaled-dot-product attention (BS, NH, CL, DH).""" + scale = q.shape[-1] ** -0.5 + scores = torch.matmul(q, k.transpose(-2, -1)) * scale + if attention_mask is not None: + scores = scores + attention_mask + weights = F.softmax(scores, dim=-1) + return torch.matmul(weights, v) + + +def _make_qkv(bs=1, nh=2, cl=8, dh=16): + """Build random (q, k, v) tensors of shape (BS, NH, CL, DH).""" + q = torch.randn(bs, nh, cl, dh) + k = torch.randn(bs, nh, cl, dh) + v = torch.randn(bs, nh, cl, dh) + return q, k, v + + +# --------------------------------------------------------------------------- +# 1. Module importability +# --------------------------------------------------------------------------- + + +@pytest.mark.diffusers +class TestDiffusersModuleImportability: + """All QEfficient/diffusers sub-modules must be importable on CPU.""" + + def test_diffusers_init_importable(self): + import QEfficient.diffusers + + assert QEfficient.diffusers is not None + + def test_modeling_utils_importable(self): + import QEfficient.diffusers.models.modeling_utils + + assert QEfficient.diffusers.models.modeling_utils is not None + + def test_normalization_importable(self): + import QEfficient.diffusers.models.normalization + + assert QEfficient.diffusers.models.normalization is not None + + def test_pytorch_transforms_importable(self): + import QEfficient.diffusers.models.pytorch_transforms + + assert QEfficient.diffusers.models.pytorch_transforms is not None + + def test_transformer_flux_importable(self): + import QEfficient.diffusers.models.transformers.transformer_flux + + assert QEfficient.diffusers.models.transformers.transformer_flux is not None + + def test_pipeline_utils_importable(self): + import QEfficient.diffusers.pipelines.pipeline_utils + + assert QEfficient.diffusers.pipelines.pipeline_utils is not None + + def test_pipeline_module_importable(self): + import QEfficient.diffusers.pipelines.pipeline_module + + assert QEfficient.diffusers.pipelines.pipeline_module is not None + + def test_get_attention_blocking_config_importable(self): + from QEfficient.diffusers.models.modeling_utils import get_attention_blocking_config + + assert callable(get_attention_blocking_config) + + def test_compute_blocked_attention_importable(self): + from QEfficient.diffusers.models.modeling_utils import compute_blocked_attention + + assert callable(compute_blocked_attention) + + def test_qeff_flux_transformer_2d_model_importable(self): + from QEfficient.diffusers.models.transformers.transformer_flux import QEffFluxTransformer2DModel + + assert QEffFluxTransformer2DModel is not None + + def test_qeff_ada_layer_norm_zero_importable(self): + from QEfficient.diffusers.models.normalization import QEffAdaLayerNormZero + + assert QEffAdaLayerNormZero is not None + + def test_qeff_pipeline_output_importable(self): + from QEfficient.diffusers.pipelines.pipeline_utils import QEffPipelineOutput + + assert QEffPipelineOutput is not None + + +# --------------------------------------------------------------------------- +# 2. Attention blocking config +# --------------------------------------------------------------------------- + + +@pytest.mark.diffusers +class TestAttentionBlockingConfig: + """get_attention_blocking_config must parse env vars correctly.""" + + def _get_config(self, mode=None, head_block=None, kv_blocks=None, q_blocks=None): + """Helper: set env vars, call get_attention_blocking_config, restore.""" + from QEfficient.diffusers.models.modeling_utils import get_attention_blocking_config + + env_backup = {} + keys = { + "ATTENTION_BLOCKING_MODE": mode, + "head_block_size": head_block, + "num_kv_blocks": kv_blocks, + "num_q_blocks": q_blocks, + } + for k, v in keys.items(): + env_backup[k] = os.environ.get(k) + if v is not None: + os.environ[k] = str(v) + elif k in os.environ: + del os.environ[k] + try: + return get_attention_blocking_config() + finally: + for k, v in env_backup.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v + + def test_default_mode_is_default(self): + blocking_mode, _, _, _ = self._get_config() + assert blocking_mode == "default", f"Default blocking mode must be 'default', got '{blocking_mode}'" + + def test_default_head_block_size_is_none_or_positive(self): + """Default head_block_size is None (unused in 'default' mode) or a positive int.""" + _, head_block_size, _, _ = self._get_config() + assert head_block_size is None or head_block_size > 0 + + def test_default_num_kv_blocks_is_none_or_positive(self): + """Default num_kv_blocks is None (unused in 'default' mode) or a positive int.""" + _, _, num_kv_blocks, _ = self._get_config() + assert num_kv_blocks is None or num_kv_blocks > 0 + + def test_default_num_q_blocks_is_none_or_positive(self): + """Default num_q_blocks is None (unused in 'default' mode) or a positive int.""" + _, _, _, num_q_blocks = self._get_config() + assert num_q_blocks is None or num_q_blocks > 0 + + def test_custom_mode_kv(self): + blocking_mode, _, _, _ = self._get_config(mode="kv") + assert blocking_mode == "kv" + + def test_custom_mode_q(self): + blocking_mode, _, _, _ = self._get_config(mode="q") + assert blocking_mode == "q" + + def test_custom_mode_qkv(self): + blocking_mode, _, _, _ = self._get_config(mode="qkv") + assert blocking_mode == "qkv" + + def test_custom_head_block_size(self): + _, head_block_size, _, _ = self._get_config(head_block=4) + assert head_block_size == 4 + + def test_custom_num_kv_blocks(self): + _, _, num_kv_blocks, _ = self._get_config(kv_blocks=8) + assert num_kv_blocks == 8 + + def test_custom_num_q_blocks(self): + _, _, _, num_q_blocks = self._get_config(q_blocks=16) + assert num_q_blocks == 16 + + def test_returns_four_values(self): + result = self._get_config() + assert len(result) == 4 + + def test_invalid_mode_raises_value_error(self): + from QEfficient.diffusers.models.modeling_utils import get_attention_blocking_config + + os.environ["ATTENTION_BLOCKING_MODE"] = "invalid_xyz_mode" + try: + with pytest.raises((ValueError, KeyError)): + get_attention_blocking_config() + finally: + del os.environ["ATTENTION_BLOCKING_MODE"] + + +# --------------------------------------------------------------------------- +# 3. Head blocking attention +# --------------------------------------------------------------------------- + + +@pytest.mark.diffusers +@pytest.mark.accuracy +class TestHeadBlockingAttention: + """apply_head_blocking must produce correct outputs on CPU.""" + + def test_output_shape_matches_input(self): + from QEfficient.diffusers.models.modeling_utils import apply_head_blocking + + q, k, v = _make_qkv(bs=1, nh=4, cl=8, dh=16) + out = apply_head_blocking(q, k, v, head_block_size=2) + assert out.shape == q.shape, f"Expected {q.shape}, got {out.shape}" + + def test_output_is_finite(self): + from QEfficient.diffusers.models.modeling_utils import apply_head_blocking + + q, k, v = _make_qkv(bs=1, nh=4, cl=8, dh=16) + out = apply_head_blocking(q, k, v, head_block_size=2) + assert torch.isfinite(out).all(), "apply_head_blocking output contains NaN/Inf" + + def test_small_seq_matches_standard_attention(self): + """For CL <= 512, head blocking must match standard attention exactly.""" + from QEfficient.diffusers.models.modeling_utils import apply_head_blocking + + q, k, v = _make_qkv(bs=1, nh=2, cl=8, dh=16) + ref = _standard_attention(q, k, v) + out = apply_head_blocking(q, k, v, head_block_size=1) + max_diff = (ref - out).abs().max().item() + assert max_diff < 1e-4, f"Head blocking vs standard attention max_diff={max_diff:.2e}" + + def test_batch_size_2_works(self): + from QEfficient.diffusers.models.modeling_utils import apply_head_blocking + + q, k, v = _make_qkv(bs=2, nh=4, cl=8, dh=16) + out = apply_head_blocking(q, k, v, head_block_size=2) + assert out.shape == q.shape + assert torch.isfinite(out).all() + + def test_single_head_block_size_equals_num_heads(self): + """head_block_size == num_heads should process all heads at once.""" + from QEfficient.diffusers.models.modeling_utils import apply_head_blocking + + q, k, v = _make_qkv(bs=1, nh=4, cl=8, dh=16) + out = apply_head_blocking(q, k, v, head_block_size=4) + assert out.shape == q.shape + assert torch.isfinite(out).all() + + +# --------------------------------------------------------------------------- +# 4. KV blocking attention +# --------------------------------------------------------------------------- + + +@pytest.mark.diffusers +@pytest.mark.accuracy +class TestKVBlockingAttention: + """apply_kv_blocking must produce correct outputs on CPU.""" + + def test_output_shape_matches_input(self): + from QEfficient.diffusers.models.modeling_utils import apply_kv_blocking + + q, k, v = _make_qkv(bs=1, nh=2, cl=8, dh=16) + out = apply_kv_blocking(q, k, v, head_block_size=2, num_kv_blocks=2) + assert out.shape == q.shape, f"Expected {q.shape}, got {out.shape}" + + def test_output_is_finite(self): + from QEfficient.diffusers.models.modeling_utils import apply_kv_blocking + + q, k, v = _make_qkv(bs=1, nh=2, cl=8, dh=16) + out = apply_kv_blocking(q, k, v, head_block_size=2, num_kv_blocks=2) + assert torch.isfinite(out).all() + + def test_small_seq_matches_standard_attention(self): + """For CL <= 512, kv blocking must match standard attention.""" + from QEfficient.diffusers.models.modeling_utils import apply_kv_blocking + + q, k, v = _make_qkv(bs=1, nh=2, cl=8, dh=16) + ref = _standard_attention(q, k, v) + out = apply_kv_blocking(q, k, v, head_block_size=2, num_kv_blocks=1) + max_diff = (ref - out).abs().max().item() + assert max_diff < 1e-4, f"KV blocking vs standard attention max_diff={max_diff:.2e}" + + def test_batch_size_2_works(self): + from QEfficient.diffusers.models.modeling_utils import apply_kv_blocking + + q, k, v = _make_qkv(bs=2, nh=2, cl=8, dh=16) + out = apply_kv_blocking(q, k, v, head_block_size=2, num_kv_blocks=2) + assert out.shape == q.shape + assert torch.isfinite(out).all() + + +# --------------------------------------------------------------------------- +# 5. Q blocking attention +# --------------------------------------------------------------------------- + + +@pytest.mark.diffusers +@pytest.mark.accuracy +class TestQBlockingAttention: + """apply_q_blocking must produce correct outputs on CPU.""" + + def test_output_shape_matches_input(self): + from QEfficient.diffusers.models.modeling_utils import apply_q_blocking + + q, k, v = _make_qkv(bs=1, nh=2, cl=8, dh=16) + out = apply_q_blocking(q, k, v, head_block_size=2, num_q_blocks=2) + assert out.shape == q.shape, f"Expected {q.shape}, got {out.shape}" + + def test_output_is_finite(self): + from QEfficient.diffusers.models.modeling_utils import apply_q_blocking + + q, k, v = _make_qkv(bs=1, nh=2, cl=8, dh=16) + out = apply_q_blocking(q, k, v, head_block_size=2, num_q_blocks=2) + assert torch.isfinite(out).all() + + def test_small_seq_matches_standard_attention(self): + """For CL <= 512, q blocking must match standard attention.""" + from QEfficient.diffusers.models.modeling_utils import apply_q_blocking + + q, k, v = _make_qkv(bs=1, nh=2, cl=8, dh=16) + ref = _standard_attention(q, k, v) + out = apply_q_blocking(q, k, v, head_block_size=2, num_q_blocks=1) + max_diff = (ref - out).abs().max().item() + assert max_diff < 1e-4, f"Q blocking vs standard attention max_diff={max_diff:.2e}" + + def test_batch_size_2_works(self): + from QEfficient.diffusers.models.modeling_utils import apply_q_blocking + + q, k, v = _make_qkv(bs=2, nh=2, cl=8, dh=16) + out = apply_q_blocking(q, k, v, head_block_size=2, num_q_blocks=2) + assert out.shape == q.shape + assert torch.isfinite(out).all() + + +# --------------------------------------------------------------------------- +# 6. QKV blocking attention +# --------------------------------------------------------------------------- + + +@pytest.mark.diffusers +@pytest.mark.accuracy +class TestQKVBlockingAttention: + """apply_qkv_blocking must produce correct outputs on CPU.""" + + def test_output_shape_matches_input(self): + from QEfficient.diffusers.models.modeling_utils import apply_qkv_blocking + + q, k, v = _make_qkv(bs=1, nh=2, cl=8, dh=16) + out = apply_qkv_blocking(q, k, v, head_block_size=2, num_kv_blocks=2, num_q_blocks=2) + assert out.shape == q.shape, f"Expected {q.shape}, got {out.shape}" + + def test_output_is_finite(self): + from QEfficient.diffusers.models.modeling_utils import apply_qkv_blocking + + q, k, v = _make_qkv(bs=1, nh=2, cl=8, dh=16) + out = apply_qkv_blocking(q, k, v, head_block_size=2, num_kv_blocks=2, num_q_blocks=2) + assert torch.isfinite(out).all() + + def test_small_seq_matches_standard_attention(self): + """For CL <= 512, qkv blocking must match standard attention.""" + from QEfficient.diffusers.models.modeling_utils import apply_qkv_blocking + + q, k, v = _make_qkv(bs=1, nh=2, cl=8, dh=16) + ref = _standard_attention(q, k, v) + out = apply_qkv_blocking(q, k, v, head_block_size=2, num_kv_blocks=1, num_q_blocks=1) + max_diff = (ref - out).abs().max().item() + assert max_diff < 1e-4, f"QKV blocking vs standard attention max_diff={max_diff:.2e}" + + def test_batch_size_2_works(self): + from QEfficient.diffusers.models.modeling_utils import apply_qkv_blocking + + q, k, v = _make_qkv(bs=2, nh=2, cl=8, dh=16) + out = apply_qkv_blocking(q, k, v, head_block_size=2, num_kv_blocks=2, num_q_blocks=2) + assert out.shape == q.shape + assert torch.isfinite(out).all() + + +# --------------------------------------------------------------------------- +# 7. compute_blocked_attention dispatcher +# --------------------------------------------------------------------------- + + +@pytest.mark.diffusers +@pytest.mark.accuracy +class TestComputeBlockedAttention: + """compute_blocked_attention must dispatch to the correct function.""" + + def test_head_mode_output_shape(self): + from QEfficient.diffusers.models.modeling_utils import compute_blocked_attention + + q, k, v = _make_qkv(bs=1, nh=4, cl=8, dh=16) + out = compute_blocked_attention( + q, k, v, head_block_size=2, num_kv_blocks=2, num_q_blocks=2, blocking_mode="head" + ) + assert out.shape == q.shape + + def test_kv_mode_output_shape(self): + from QEfficient.diffusers.models.modeling_utils import compute_blocked_attention + + q, k, v = _make_qkv(bs=1, nh=2, cl=8, dh=16) + out = compute_blocked_attention(q, k, v, head_block_size=2, num_kv_blocks=2, num_q_blocks=2, blocking_mode="kv") + assert out.shape == q.shape + + def test_q_mode_output_shape(self): + from QEfficient.diffusers.models.modeling_utils import compute_blocked_attention + + q, k, v = _make_qkv(bs=1, nh=2, cl=8, dh=16) + out = compute_blocked_attention(q, k, v, head_block_size=2, num_kv_blocks=2, num_q_blocks=2, blocking_mode="q") + assert out.shape == q.shape + + def test_qkv_mode_output_shape(self): + from QEfficient.diffusers.models.modeling_utils import compute_blocked_attention + + q, k, v = _make_qkv(bs=1, nh=2, cl=8, dh=16) + out = compute_blocked_attention( + q, k, v, head_block_size=2, num_kv_blocks=2, num_q_blocks=2, blocking_mode="qkv" + ) + assert out.shape == q.shape + + def test_all_modes_produce_finite_outputs(self): + """All four blocking modes must produce finite outputs.""" + from QEfficient.diffusers.models.modeling_utils import compute_blocked_attention + + q, k, v = _make_qkv(bs=1, nh=4, cl=8, dh=16) + for mode in ["head", "kv", "q", "qkv"]: + out = compute_blocked_attention( + q, k, v, head_block_size=2, num_kv_blocks=2, num_q_blocks=2, blocking_mode=mode + ) + assert torch.isfinite(out).all(), f"Mode '{mode}' produced NaN/Inf" + + def test_small_seq_all_modes_agree(self): + """For CL <= 512, all modes must produce the same result as standard attention.""" + from QEfficient.diffusers.models.modeling_utils import compute_blocked_attention + + q, k, v = _make_qkv(bs=1, nh=4, cl=8, dh=16) + ref = _standard_attention(q, k, v) + + for mode in ["head", "kv", "q", "qkv"]: + out = compute_blocked_attention( + q, k, v, head_block_size=1, num_kv_blocks=1, num_q_blocks=1, blocking_mode=mode + ) + max_diff = (ref - out).abs().max().item() + assert max_diff < 1e-4, f"Mode '{mode}' vs standard attention max_diff={max_diff:.2e}" + + def test_with_attention_mask(self): + """compute_blocked_attention must accept an optional boolean attention_mask.""" + from QEfficient.diffusers.models.modeling_utils import compute_blocked_attention + + q, k, v = _make_qkv(bs=1, nh=2, cl=8, dh=16) + # attention_mask must be boolean (True = masked/ignored position) + mask = torch.zeros(1, 1, 8, 8, dtype=torch.bool) + out = compute_blocked_attention( + q, k, v, head_block_size=2, num_kv_blocks=2, num_q_blocks=2, blocking_mode="head", attention_mask=mask + ) + assert out.shape == q.shape + assert torch.isfinite(out).all() + + +# --------------------------------------------------------------------------- +# 8. QEff normalization layers +# --------------------------------------------------------------------------- + + +@pytest.mark.diffusers +@pytest.mark.accuracy +class TestQEffNormalizationLayers: + """QEff normalization layers must produce correct outputs on CPU.""" + + def _make_ada_layer_norm_zero(self, embedding_dim=16): + from QEfficient.diffusers.models.normalization import QEffAdaLayerNormZero + + return QEffAdaLayerNormZero(embedding_dim=embedding_dim).eval() + + def _make_ada_layer_norm_zero_single(self, embedding_dim=16): + from QEfficient.diffusers.models.normalization import QEffAdaLayerNormZeroSingle + + return QEffAdaLayerNormZeroSingle(embedding_dim=embedding_dim).eval() + + def _make_ada_layer_norm_continuous(self, embedding_dim=16, conditioning_dim=16): + from QEfficient.diffusers.models.normalization import QEffAdaLayerNormContinuous + + return QEffAdaLayerNormContinuous( + embedding_dim=embedding_dim, + conditioning_embedding_dim=conditioning_dim, + ).eval() + + def test_ada_layer_norm_zero_instantiates(self): + norm = self._make_ada_layer_norm_zero() + assert norm is not None + + def test_ada_layer_norm_zero_single_instantiates(self): + norm = self._make_ada_layer_norm_zero_single() + assert norm is not None + + def test_ada_layer_norm_continuous_instantiates(self): + norm = self._make_ada_layer_norm_continuous() + assert norm is not None + + def test_ada_layer_norm_zero_output_shape(self): + """QEffAdaLayerNormZero.forward must return tensor of same shape as input.""" + norm = self._make_ada_layer_norm_zero(embedding_dim=16) + x = torch.randn(1, 8, 16) + shift_msa = torch.randn(1, 16) + scale_msa = torch.randn(1, 16) + with torch.no_grad(): + out = norm(x, shift_msa=shift_msa, scale_msa=scale_msa) + assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}" + + def test_ada_layer_norm_zero_output_is_finite(self): + norm = self._make_ada_layer_norm_zero(embedding_dim=16) + x = torch.randn(1, 8, 16) + shift_msa = torch.randn(1, 16) + scale_msa = torch.randn(1, 16) + with torch.no_grad(): + out = norm(x, shift_msa=shift_msa, scale_msa=scale_msa) + assert torch.isfinite(out).all() + + def test_ada_layer_norm_zero_single_output_shape(self): + """QEffAdaLayerNormZeroSingle.forward must return tensor of same shape as input.""" + norm = self._make_ada_layer_norm_zero_single(embedding_dim=16) + x = torch.randn(1, 8, 16) + shift_msa = torch.randn(1, 16) + scale_msa = torch.randn(1, 16) + with torch.no_grad(): + out = norm(x, scale_msa=scale_msa, shift_msa=shift_msa) + assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}" + + def test_ada_layer_norm_zero_single_output_is_finite(self): + norm = self._make_ada_layer_norm_zero_single(embedding_dim=16) + x = torch.randn(1, 8, 16) + with torch.no_grad(): + out = norm(x, scale_msa=torch.randn(1, 16), shift_msa=torch.randn(1, 16)) + assert torch.isfinite(out).all() + + def test_ada_layer_norm_continuous_output_shape(self): + """QEffAdaLayerNormContinuous.forward must return tensor of same shape as input.""" + norm = self._make_ada_layer_norm_continuous(embedding_dim=16, conditioning_dim=16) + x = torch.randn(1, 8, 16) + # conditioning_embedding is pre-computed: shape (batch, 2 * embedding_dim) + conditioning = torch.randn(1, 32) + with torch.no_grad(): + out = norm(x, conditioning) + assert out.shape == x.shape, f"Expected {x.shape}, got {out.shape}" + + def test_ada_layer_norm_continuous_output_is_finite(self): + norm = self._make_ada_layer_norm_continuous(embedding_dim=16, conditioning_dim=16) + x = torch.randn(1, 8, 16) + conditioning = torch.randn(1, 32) + with torch.no_grad(): + out = norm(x, conditioning) + assert torch.isfinite(out).all() + + def test_ada_layer_norm_zero_zero_shift_scale_preserves_norm(self): + """With zero shift and scale, output should equal LayerNorm(x).""" + norm = self._make_ada_layer_norm_zero(embedding_dim=16) + x = torch.randn(1, 8, 16) + shift_msa = torch.zeros(1, 16) + scale_msa = torch.zeros(1, 16) + with torch.no_grad(): + out = norm(x, shift_msa=shift_msa, scale_msa=scale_msa) + # With zero shift and scale: out = LayerNorm(x) * (1 + 0) + 0 = LayerNorm(x) + ln = torch.nn.LayerNorm(16, elementwise_affine=False, eps=1e-6) + expected = ln(x) + max_diff = (out - expected).abs().max().item() + assert max_diff < 1e-5, f"Zero shift/scale: max_diff={max_diff:.2e}" + + def test_ada_layer_norm_continuous_batch_size_2(self): + norm = self._make_ada_layer_norm_continuous(embedding_dim=16, conditioning_dim=16) + x = torch.randn(2, 8, 16) + conditioning = torch.randn(2, 32) + with torch.no_grad(): + out = norm(x, conditioning) + assert out.shape == (2, 8, 16) + assert torch.isfinite(out).all() + + +# --------------------------------------------------------------------------- +# 9. Diffusers transforms structure +# --------------------------------------------------------------------------- + + +@pytest.mark.diffusers +class TestDiffusersTransforms: + """Diffusers transforms must have correct class-level structure.""" + + def test_custom_ops_transform_importable(self): + from QEfficient.diffusers.models.pytorch_transforms import CustomOpsTransform + + assert CustomOpsTransform is not None + + def test_attention_transform_importable(self): + from QEfficient.diffusers.models.pytorch_transforms import AttentionTransform + + assert AttentionTransform is not None + + def test_normalization_transform_importable(self): + from QEfficient.diffusers.models.pytorch_transforms import NormalizationTransform + + assert NormalizationTransform is not None + + def test_custom_ops_transform_has_module_mapping(self): + from QEfficient.diffusers.models.pytorch_transforms import CustomOpsTransform + + assert hasattr(CustomOpsTransform, "_module_mapping") + assert len(CustomOpsTransform._module_mapping) > 0 + + def test_attention_transform_has_module_mapping(self): + from QEfficient.diffusers.models.pytorch_transforms import AttentionTransform + + assert hasattr(AttentionTransform, "_module_mapping") + assert len(AttentionTransform._module_mapping) > 0 + + def test_normalization_transform_has_module_mapping(self): + from QEfficient.diffusers.models.pytorch_transforms import NormalizationTransform + + assert hasattr(NormalizationTransform, "_module_mapping") + assert len(NormalizationTransform._module_mapping) > 0 + + def test_attention_transform_maps_flux_attention(self): + from diffusers.models.transformers.transformer_flux import FluxAttention + + from QEfficient.diffusers.models.pytorch_transforms import AttentionTransform + from QEfficient.diffusers.models.transformers.transformer_flux import QEffFluxAttention + + assert FluxAttention in AttentionTransform._module_mapping + assert AttentionTransform._module_mapping[FluxAttention] is QEffFluxAttention + + def test_attention_transform_maps_flux_transformer_block(self): + from diffusers.models.transformers.transformer_flux import FluxTransformerBlock + + from QEfficient.diffusers.models.pytorch_transforms import AttentionTransform + from QEfficient.diffusers.models.transformers.transformer_flux import QEffFluxTransformerBlock + + assert FluxTransformerBlock in AttentionTransform._module_mapping + assert AttentionTransform._module_mapping[FluxTransformerBlock] is QEffFluxTransformerBlock + + def test_attention_transform_maps_flux_single_transformer_block(self): + from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock + + from QEfficient.diffusers.models.pytorch_transforms import AttentionTransform + from QEfficient.diffusers.models.transformers.transformer_flux import QEffFluxSingleTransformerBlock + + assert FluxSingleTransformerBlock in AttentionTransform._module_mapping + assert AttentionTransform._module_mapping[FluxSingleTransformerBlock] is QEffFluxSingleTransformerBlock + + def test_attention_transform_maps_flux_transformer_2d_model(self): + from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel + + from QEfficient.diffusers.models.pytorch_transforms import AttentionTransform + from QEfficient.diffusers.models.transformers.transformer_flux import QEffFluxTransformer2DModel + + assert FluxTransformer2DModel in AttentionTransform._module_mapping + assert AttentionTransform._module_mapping[FluxTransformer2DModel] is QEffFluxTransformer2DModel + + def test_normalization_transform_maps_ada_layer_norm_zero(self): + from diffusers.models.normalization import AdaLayerNormZero + + from QEfficient.diffusers.models.normalization import QEffAdaLayerNormZero + from QEfficient.diffusers.models.pytorch_transforms import NormalizationTransform + + assert AdaLayerNormZero in NormalizationTransform._module_mapping + assert NormalizationTransform._module_mapping[AdaLayerNormZero] is QEffAdaLayerNormZero + + def test_normalization_transform_maps_ada_layer_norm_zero_single(self): + from diffusers.models.normalization import AdaLayerNormZeroSingle + + from QEfficient.diffusers.models.normalization import QEffAdaLayerNormZeroSingle + from QEfficient.diffusers.models.pytorch_transforms import NormalizationTransform + + assert AdaLayerNormZeroSingle in NormalizationTransform._module_mapping + assert NormalizationTransform._module_mapping[AdaLayerNormZeroSingle] is QEffAdaLayerNormZeroSingle + + def test_all_transforms_have_apply_method(self): + from QEfficient.diffusers.models.pytorch_transforms import ( + AttentionTransform, + CustomOpsTransform, + NormalizationTransform, + ) + + for cls in [CustomOpsTransform, AttentionTransform, NormalizationTransform]: + assert hasattr(cls, "apply"), f"{cls.__name__} missing apply method" + assert callable(cls.apply), f"{cls.__name__}.apply is not callable" + + +# --------------------------------------------------------------------------- +# 10. Pipeline utilities +# --------------------------------------------------------------------------- + + +@pytest.mark.diffusers +class TestPipelineUtils: + """Pipeline utility functions must produce correct results.""" + + def test_calculate_compressed_latent_dimension_importable(self): + from QEfficient.diffusers.pipelines.pipeline_utils import calculate_compressed_latent_dimension + + assert callable(calculate_compressed_latent_dimension) + + def test_calculate_latent_dimensions_with_frames_importable(self): + from QEfficient.diffusers.pipelines.pipeline_utils import calculate_latent_dimensions_with_frames + + assert callable(calculate_latent_dimensions_with_frames) + + def test_compressed_latent_dimension_basic(self): + """calculate_compressed_latent_dimension returns (cl, latent_h, latent_w). + cl = (latent_h * latent_w) // 4 (Flux 2x2 packing). + For H=64, W=64, vsf=8: latent_h=8, latent_w=8, cl=(8*8)//4=16. + """ + from QEfficient.diffusers.pipelines.pipeline_utils import calculate_compressed_latent_dimension + + cl, latent_h, latent_w = calculate_compressed_latent_dimension(height=64, width=64, vae_scale_factor=8) + assert latent_h == 8, f"Expected latent_h=8, got {latent_h}" + assert latent_w == 8, f"Expected latent_w=8, got {latent_w}" + assert cl == 16, f"Expected cl=16 (=(8*8)//4), got {cl}" + + def test_compressed_latent_dimension_non_square(self): + """For H=64, W=128, vsf=8: latent_h=8, latent_w=16, cl=(8*16)//4=32.""" + from QEfficient.diffusers.pipelines.pipeline_utils import calculate_compressed_latent_dimension + + cl, latent_h, latent_w = calculate_compressed_latent_dimension(height=64, width=128, vae_scale_factor=8) + assert latent_h == 8, f"Expected latent_h=8, got {latent_h}" + assert latent_w == 16, f"Expected latent_w=16, got {latent_w}" + assert cl == 32, f"Expected cl=32 (=(8*16)//4), got {cl}" + + def test_compressed_latent_dimension_patch_size_1(self): + """For H=16, W=16, vsf=1: latent_h=16, latent_w=16, cl=(16*16)//4=64.""" + from QEfficient.diffusers.pipelines.pipeline_utils import calculate_compressed_latent_dimension + + cl, latent_h, latent_w = calculate_compressed_latent_dimension(height=16, width=16, vae_scale_factor=1) + assert latent_h == 16, f"Expected latent_h=16, got {latent_h}" + assert latent_w == 16, f"Expected latent_w=16, got {latent_w}" + assert cl == 64, f"Expected cl=64 (=(16*16)//4), got {cl}" + + def test_compressed_latent_dimension_returns_tuple_of_ints(self): + """calculate_compressed_latent_dimension must return a tuple of 3 ints.""" + from QEfficient.diffusers.pipelines.pipeline_utils import calculate_compressed_latent_dimension + + result = calculate_compressed_latent_dimension(height=64, width=64, vae_scale_factor=8) + assert isinstance(result, tuple), f"Expected tuple, got {type(result)}" + assert len(result) == 3, f"Expected 3-tuple, got length {len(result)}" + cl, latent_h, latent_w = result + assert isinstance(cl, int), f"Expected cl to be int, got {type(cl)}" + assert isinstance(latent_h, int), f"Expected latent_h to be int, got {type(latent_h)}" + assert isinstance(latent_w, int), f"Expected latent_w to be int, got {type(latent_w)}" + + def test_latent_dimensions_with_frames_returns_tuple(self): + from QEfficient.diffusers.pipelines.pipeline_utils import calculate_latent_dimensions_with_frames + + result = calculate_latent_dimensions_with_frames( + height=64, + width=64, + num_frames=16, + vae_scale_factor_spatial=2, + vae_scale_factor_temporal=4, + patch_height=2, + patch_width=2, + ) + assert isinstance(result, (tuple, list, int)), f"Unexpected return type: {type(result)}" + + def test_latent_dimensions_with_frames_is_positive(self): + from QEfficient.diffusers.pipelines.pipeline_utils import calculate_latent_dimensions_with_frames + + result = calculate_latent_dimensions_with_frames( + height=64, + width=64, + num_frames=16, + vae_scale_factor_spatial=2, + vae_scale_factor_temporal=4, + patch_height=2, + patch_width=2, + ) + if isinstance(result, (tuple, list)): + assert all(r > 0 for r in result), "All dimensions must be positive" + else: + assert result > 0 + + def test_module_perf_importable(self): + from QEfficient.diffusers.pipelines.pipeline_utils import ModulePerf + + assert ModulePerf is not None + + def test_module_perf_instantiable(self): + from QEfficient.diffusers.pipelines.pipeline_utils import ModulePerf + + perf = ModulePerf(module_name="test", perf=100) + assert perf is not None + + def test_module_perf_has_expected_fields(self): + from QEfficient.diffusers.pipelines.pipeline_utils import ModulePerf + + perf = ModulePerf(module_name="test", perf=100) + assert hasattr(perf, "module_name") + assert hasattr(perf, "perf") + + def test_qeff_pipeline_output_importable(self): + from QEfficient.diffusers.pipelines.pipeline_utils import QEffPipelineOutput + + assert QEffPipelineOutput is not None + + def test_qeff_pipeline_output_instantiable(self): + import numpy as np + + from QEfficient.diffusers.pipelines.pipeline_utils import ModulePerf, QEffPipelineOutput + + output = QEffPipelineOutput( + pipeline_module=[ModulePerf(module_name="test", perf=100)], images=np.zeros((1, 64, 64, 3)) + ) + assert output is not None + + def test_qeff_pipeline_output_has_images(self): + import numpy as np + + from QEfficient.diffusers.pipelines.pipeline_utils import ModulePerf, QEffPipelineOutput + + images = np.zeros((1, 64, 64, 3)) + output = QEffPipelineOutput(pipeline_module=[ModulePerf(module_name="test", perf=100)], images=images) + assert hasattr(output, "images") + assert output.images is images + + +# --------------------------------------------------------------------------- +# 11. Pipeline module class structure +# --------------------------------------------------------------------------- + + +@pytest.mark.diffusers +class TestPipelineModuleStructure: + """Pipeline module classes must have correct class-level structure.""" + + def test_qeff_text_encoder_importable(self): + from QEfficient.diffusers.pipelines.pipeline_module import QEffTextEncoder + + assert QEffTextEncoder is not None + + def test_qeff_vae_importable(self): + from QEfficient.diffusers.pipelines.pipeline_module import QEffVAE + + assert QEffVAE is not None + + def test_qeff_flux_transformer_model_importable(self): + from QEfficient.diffusers.pipelines.pipeline_module import QEffFluxTransformerModel + + assert QEffFluxTransformerModel is not None + + def test_qeff_wan_unified_transformer_importable(self): + from QEfficient.diffusers.pipelines.pipeline_module import QEffWanUnifiedTransformer + + assert QEffWanUnifiedTransformer is not None + + def test_qeff_text_encoder_has_pytorch_transforms(self): + from QEfficient.diffusers.pipelines.pipeline_module import QEffTextEncoder + + assert hasattr(QEffTextEncoder, "_pytorch_transforms") + assert isinstance(QEffTextEncoder._pytorch_transforms, list) + + def test_qeff_text_encoder_has_onnx_transforms(self): + from QEfficient.diffusers.pipelines.pipeline_module import QEffTextEncoder + + assert hasattr(QEffTextEncoder, "_onnx_transforms") + assert isinstance(QEffTextEncoder._onnx_transforms, list) + + def test_qeff_flux_transformer_model_has_pytorch_transforms(self): + from QEfficient.diffusers.pipelines.pipeline_module import QEffFluxTransformerModel + + assert hasattr(QEffFluxTransformerModel, "_pytorch_transforms") + assert isinstance(QEffFluxTransformerModel._pytorch_transforms, list) + + def test_qeff_flux_transformer_model_has_onnx_transforms(self): + from QEfficient.diffusers.pipelines.pipeline_module import QEffFluxTransformerModel + + assert hasattr(QEffFluxTransformerModel, "_onnx_transforms") + assert isinstance(QEffFluxTransformerModel._onnx_transforms, list) + + def test_qeff_flux_transformer_model_pytorch_transforms_include_attention(self): + from QEfficient.diffusers.models.pytorch_transforms import AttentionTransform + from QEfficient.diffusers.pipelines.pipeline_module import QEffFluxTransformerModel + + assert AttentionTransform in QEffFluxTransformerModel._pytorch_transforms, ( + "AttentionTransform not in QEffFluxTransformerModel._pytorch_transforms" + ) + + def test_qeff_flux_transformer_model_pytorch_transforms_include_normalization(self): + from QEfficient.diffusers.models.pytorch_transforms import NormalizationTransform + from QEfficient.diffusers.pipelines.pipeline_module import QEffFluxTransformerModel + + assert NormalizationTransform in QEffFluxTransformerModel._pytorch_transforms, ( + "NormalizationTransform not in QEffFluxTransformerModel._pytorch_transforms" + ) + + def test_qeff_text_encoder_pytorch_transforms_include_custom_ops(self): + from QEfficient.diffusers.models.pytorch_transforms import CustomOpsTransform + from QEfficient.diffusers.pipelines.pipeline_module import QEffTextEncoder + + assert CustomOpsTransform in QEffTextEncoder._pytorch_transforms, ( + "CustomOpsTransform not in QEffTextEncoder._pytorch_transforms" + ) + + def test_qeff_text_encoder_onnx_transforms_include_fp16_clip(self): + from QEfficient.base.onnx_transforms import FP16ClipTransform + from QEfficient.diffusers.pipelines.pipeline_module import QEffTextEncoder + + assert FP16ClipTransform in QEffTextEncoder._onnx_transforms, ( + "FP16ClipTransform not in QEffTextEncoder._onnx_transforms" + ) + + def test_qeff_flux_transformer_model_onnx_transforms_include_fp16_clip(self): + from QEfficient.base.onnx_transforms import FP16ClipTransform + from QEfficient.diffusers.pipelines.pipeline_module import QEffFluxTransformerModel + + assert FP16ClipTransform in QEffFluxTransformerModel._onnx_transforms, ( + "FP16ClipTransform not in QEffFluxTransformerModel._onnx_transforms" + ) + + def test_qeff_vae_has_pytorch_transforms(self): + from QEfficient.diffusers.pipelines.pipeline_module import QEffVAE + + assert hasattr(QEffVAE, "_pytorch_transforms") + assert isinstance(QEffVAE._pytorch_transforms, list) + + def test_qeff_wan_unified_transformer_has_pytorch_transforms(self): + from QEfficient.diffusers.pipelines.pipeline_module import QEffWanUnifiedTransformer + + assert hasattr(QEffWanUnifiedTransformer, "_pytorch_transforms") + assert isinstance(QEffWanUnifiedTransformer._pytorch_transforms, list) + + +# --------------------------------------------------------------------------- +# 12. Flux transformer blocks (tiny in-memory) +# --------------------------------------------------------------------------- + + +def _make_tiny_flux_transformer(): + """ + Create a tiny QEffFluxTransformer2DModel for CPU testing. + Returns None if instantiation fails (e.g., diffusers version mismatch). + """ + try: + from diffusers.models.transformers.transformer_flux import FluxTransformer2DModel + + from QEfficient.diffusers.models.pytorch_transforms import AttentionTransform, NormalizationTransform + + model = FluxTransformer2DModel( + patch_size=1, + in_channels=4, + num_layers=1, + num_single_layers=1, + attention_head_dim=8, + num_attention_heads=2, + joint_attention_dim=16, + pooled_projection_dim=16, + guidance_embeds=False, + axes_dims_rope=[2, 2, 4], + ).eval() + + model, _ = AttentionTransform.apply(model) + model, _ = NormalizationTransform.apply(model) + return model + except Exception: + return None + + +def _make_tiny_flux_inputs(model, batch=1, cl=4, text_seq=8): + """ + Build inputs for QEffFluxTransformer2DModel.forward. + inner_dim = num_attention_heads * attention_head_dim = 2 * 8 = 16 + """ + inner_dim = 16 # 2 heads * 8 head_dim + in_channels = 4 + joint_attention_dim = 16 + pooled_projection_dim = 16 + num_layers = 1 + num_single_layers = 1 + + hidden_states = torch.randn(batch, cl, in_channels) + encoder_hidden_states = torch.randn(batch, text_seq, joint_attention_dim) + pooled_projections = torch.randn(batch, pooled_projection_dim) + timestep = torch.tensor([0.5] * batch) + img_ids = torch.zeros(cl, 3) + txt_ids = torch.zeros(text_seq, 3) + + # adaln_emb: (num_layers, 12, inner_dim) — 12 = 6 for hidden + 6 for encoder + adaln_emb = torch.randn(num_layers, 12, inner_dim) + # adaln_single_emb: (num_single_layers, 3, inner_dim) + adaln_single_emb = torch.randn(num_single_layers, 3, inner_dim) + # adaln_out: (batch, 2 * inner_dim) — pre-computed scale+shift for norm_out + adaln_out = torch.randn(batch, 2 * inner_dim) + + return { + "hidden_states": hidden_states, + "encoder_hidden_states": encoder_hidden_states, + "pooled_projections": pooled_projections, + "timestep": timestep, + "img_ids": img_ids, + "txt_ids": txt_ids, + "adaln_emb": adaln_emb, + "adaln_single_emb": adaln_single_emb, + "adaln_out": adaln_out, + "return_dict": False, + } + + +@pytest.mark.diffusers +@pytest.mark.accuracy +class TestFluxTransformerBlocks: + """ + QEffFluxTransformer2DModel must produce correct outputs on CPU. + Uses a tiny in-memory model (1 layer, 2 heads, dim=16) — no network downloads. + """ + + def test_qeff_flux_transformer_2d_model_wraps_without_error(self): + model = _make_tiny_flux_transformer() + if model is None: + pytest.skip("Could not instantiate tiny FluxTransformer2DModel") + from QEfficient.diffusers.models.transformers.transformer_flux import QEffFluxTransformer2DModel + + assert isinstance(model, QEffFluxTransformer2DModel), f"Expected QEffFluxTransformer2DModel, got {type(model)}" + + def test_qeff_flux_transformer_2d_model_is_eval_mode(self): + model = _make_tiny_flux_transformer() + if model is None: + pytest.skip("Could not instantiate tiny FluxTransformer2DModel") + assert not model.training, "Model must be in eval mode" + + def test_qeff_flux_transformer_2d_model_forward_returns_output(self): + model = _make_tiny_flux_transformer() + if model is None: + pytest.skip("Could not instantiate tiny FluxTransformer2DModel") + inputs = _make_tiny_flux_inputs(model) + with torch.no_grad(): + out = model(**inputs) + assert out is not None + + def test_qeff_flux_transformer_2d_model_output_shape(self): + """Output sample must have shape (batch, cl, in_channels).""" + model = _make_tiny_flux_transformer() + if model is None: + pytest.skip("Could not instantiate tiny FluxTransformer2DModel") + batch, cl, in_channels = 1, 4, 4 + inputs = _make_tiny_flux_inputs(model, batch=batch, cl=cl) + with torch.no_grad(): + out = model(**inputs) + # out is a tuple when return_dict=False; out[0] is the sample + sample = out[0] if isinstance(out, (tuple, list)) else out.sample + assert sample.shape == (batch, cl, in_channels), f"Expected ({batch}, {cl}, {in_channels}), got {sample.shape}" + + def test_qeff_flux_transformer_2d_model_output_is_finite(self): + model = _make_tiny_flux_transformer() + if model is None: + pytest.skip("Could not instantiate tiny FluxTransformer2DModel") + inputs = _make_tiny_flux_inputs(model) + with torch.no_grad(): + out = model(**inputs) + sample = out[0] if isinstance(out, (tuple, list)) else out.sample + assert torch.isfinite(sample).all(), "QEffFluxTransformer2DModel output contains NaN/Inf" + + def test_qeff_flux_transformer_2d_model_is_deterministic(self): + """Same inputs must produce the same output.""" + model = _make_tiny_flux_transformer() + if model is None: + pytest.skip("Could not instantiate tiny FluxTransformer2DModel") + inputs = _make_tiny_flux_inputs(model) + with torch.no_grad(): + out1 = model(**inputs) + out2 = model(**inputs) + s1 = out1[0] if isinstance(out1, (tuple, list)) else out1.sample + s2 = out2[0] if isinstance(out2, (tuple, list)) else out2.sample + assert torch.allclose(s1, s2), "QEffFluxTransformer2DModel is not deterministic" + + def test_qeff_flux_transformer_2d_model_get_submodules_for_export(self): + """get_submodules_for_export must return the expected QEff block classes.""" + model = _make_tiny_flux_transformer() + if model is None: + pytest.skip("Could not instantiate tiny FluxTransformer2DModel") + from QEfficient.diffusers.models.transformers.transformer_flux import ( + QEffFluxSingleTransformerBlock, + QEffFluxTransformerBlock, + ) + + submodules = model.get_submodules_for_export() + assert QEffFluxTransformerBlock in submodules, "QEffFluxTransformerBlock not in get_submodules_for_export()" + assert QEffFluxSingleTransformerBlock in submodules, ( + "QEffFluxSingleTransformerBlock not in get_submodules_for_export()" + ) + + def test_qeff_flux_attn_processor_replaces_original(self): + """After AttentionTransform, FluxAttention must use QEffFluxAttnProcessor.""" + model = _make_tiny_flux_transformer() + if model is None: + pytest.skip("Could not instantiate tiny FluxTransformer2DModel") + from QEfficient.diffusers.models.transformers.transformer_flux import ( + QEffFluxAttention, + QEffFluxAttnProcessor, + ) + + for m in model.modules(): + if isinstance(m, QEffFluxAttention): + assert isinstance(m.processor, QEffFluxAttnProcessor), ( + f"Expected QEffFluxAttnProcessor, got {type(m.processor)}" + ) + break diff --git a/tests/unit_test/utils/test_error_handling.py b/tests/unit_test/utils/test_error_handling.py new file mode 100644 index 000000000..c0fb7da66 --- /dev/null +++ b/tests/unit_test/utils/test_error_handling.py @@ -0,0 +1,359 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Error handling & edge case tests for QEfficient. + +Tests verify that the public API raises clear, descriptive errors when given +invalid inputs, rather than cryptic PyTorch/ONNX failures. + +All tests run on CPU only. +""" + +import pytest +import torch +import torch.nn as nn +from transformers import ( + BertConfig, + BertForMaskedLM, + GPT2Config, + GPT2LMHeadModel, + LlamaConfig, + LlamaForCausalLM, + Qwen2Config, + Qwen2ForCausalLM, +) + +from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_tiny_gpt2(): + cfg = GPT2Config(n_layer=1, n_head=2, n_embd=64, vocab_size=500, n_positions=32, n_ctx=32) + return GPT2LMHeadModel(cfg).eval() + + +def make_tiny_llama(): + cfg = LlamaConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + return LlamaForCausalLM(cfg).eval() + + +def make_tiny_qwen2(): + cfg = Qwen2Config( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + return Qwen2ForCausalLM(cfg).eval() + + +def make_tiny_bert(): + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=32, + ) + return BertForMaskedLM(cfg).eval() + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForCausalLM constructor error paths +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestQEFFAutoModelForCausalLMErrorPaths: + """QEFFAutoModelForCausalLM must raise TypeError for non-CausalLM models.""" + + def test_non_causal_lm_model_raises_type_error(self): + """Wrapping a BERT model (not CausalLM) must raise TypeError.""" + bert = make_tiny_bert() + with pytest.raises(TypeError, match="CausalLM|LMHeadModel"): + QEFFAutoModelForCausalLM(bert) + + def test_plain_nn_module_raises_type_error(self): + """Wrapping a plain nn.Module must raise TypeError.""" + + class SimpleModel(nn.Module): + def forward(self, x): + return x + + with pytest.raises(TypeError): + QEFFAutoModelForCausalLM(SimpleModel()) + + def test_causal_lm_model_does_not_raise(self): + """Wrapping a valid CausalLM model must not raise.""" + model = make_tiny_gpt2() + qeff = QEFFAutoModelForCausalLM(model) + assert qeff is not None + + def test_llama_causal_lm_does_not_raise(self): + """Wrapping a LlamaForCausalLM must not raise.""" + model = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model) + assert qeff is not None + + +# --------------------------------------------------------------------------- +# Tests: compile() error paths +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestQEFFAutoModelCompileErrorPaths: + """compile() must raise appropriate errors for invalid argument combinations.""" + + def test_compile_cb_without_full_batch_size_raises_type_error(self): + """compile(continuous_batching=True) without full_batch_size must raise TypeError.""" + model = make_tiny_gpt2() + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=True) + with pytest.raises(TypeError, match="full_batch_size"): + qeff.compile( + prefill_seq_len=8, + ctx_len=32, + # full_batch_size intentionally omitted + ) + + def test_compile_kv_cache_batch_size_without_full_batch_size_raises_value_error(self): + """compile(kv_cache_batch_size=N) without full_batch_size must raise ValueError.""" + model = make_tiny_gpt2() + # continuous_batching=False but kv_cache_batch_size set without full_batch_size + _ = QEFFAutoModelForCausalLM(model, continuous_batching=False) + # This should log a warning but not raise for non-CB mode + # The ValueError is raised when kv_cache_batch_size is set but full_batch_size is None + # and continuous_batching is True + qeff_cb = QEFFAutoModelForCausalLM(make_tiny_gpt2(), continuous_batching=True) + with pytest.raises((TypeError, ValueError)): + qeff_cb.compile( + prefill_seq_len=8, + ctx_len=32, + kv_cache_batch_size=4, + # full_batch_size intentionally omitted + ) + + def test_prefill_only_non_bool_raises_type_error(self): + """compile(prefill_only='yes') must raise TypeError.""" + model = make_tiny_gpt2() + qeff = QEFFAutoModelForCausalLM(model) + with pytest.raises(TypeError, match="prefill_only"): + qeff.compile( + prefill_seq_len=8, + ctx_len=32, + prefill_only="yes", # invalid: must be bool + ) + + +# --------------------------------------------------------------------------- +# Tests: check_and_get_num_speculative_tokens error paths +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestCheckNumSpeculativeTokensErrorPaths: + """check_and_get_num_speculative_tokens must raise for invalid TLM configurations.""" + + def test_tlm_without_num_speculative_tokens_raises_type_error(self): + """TLM model without num_speculative_tokens must raise TypeError.""" + model = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) + assert qeff.is_tlm is True + with pytest.raises(TypeError, match="num_speculative_tokens"): + qeff.check_and_get_num_speculative_tokens(num_speculative_tokens=None, prefill_seq_len=32) + + def test_tlm_prefill_seq_len_too_short_raises_value_error(self): + """TLM with prefill_seq_len < num_speculative_tokens+1 must raise ValueError.""" + model = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) + assert qeff.is_tlm is True + # num_speculative_tokens=5, so need prefill_seq_len >= 6 + with pytest.raises(ValueError, match="sequence length"): + qeff.check_and_get_num_speculative_tokens( + num_speculative_tokens=5, + prefill_seq_len=4, # too short + ) + + def test_tlm_valid_num_speculative_tokens_does_not_raise(self): + """TLM with valid num_speculative_tokens must not raise.""" + model = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) + result = qeff.check_and_get_num_speculative_tokens(num_speculative_tokens=3, prefill_seq_len=32) + assert result == 3 + + def test_non_tlm_returns_none(self): + """Non-TLM model must return None from check_and_get_num_speculative_tokens.""" + model = make_tiny_gpt2() + qeff = QEFFAutoModelForCausalLM(model) + result = qeff.check_and_get_num_speculative_tokens(num_speculative_tokens=None, prefill_seq_len=32) + assert result is None + + +# --------------------------------------------------------------------------- +# Tests: Transform error paths +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestTransformErrorPaths: + """Transforms must raise NotImplementedError for unsupported models.""" + + def test_spd_transform_unsupported_model_raises_not_implemented(self): + """SpDTransform must raise NotImplementedError for unsupported model class.""" + from QEfficient.transformers.models.pytorch_transforms import SpDTransform + + class UnsupportedModel(nn.Module): + def forward(self, x): + return x + + with pytest.raises(NotImplementedError): + SpDTransform.apply( + UnsupportedModel(), + qaic_config={"speculative_model_type": "target"}, + ) + + def test_spd_transform_invalid_speculative_type_raises_value_error(self): + """SpDTransform must raise ValueError for invalid speculative_model_type.""" + from QEfficient.transformers.models.pytorch_transforms import KVCacheTransform, SpDTransform + + model = make_tiny_llama() + model, _ = KVCacheTransform.apply(model) + with pytest.raises(ValueError): + SpDTransform.apply( + model, + qaic_config={"speculative_model_type": "invalid_xyz"}, + ) + + def test_pooling_transform_invalid_type_raises_value_error(self): + """PoolingTransform must raise ValueError for invalid pooling type string.""" + from QEfficient.transformers.models.pytorch_transforms import PoolingTransform + + class DummyEncoder(nn.Module): + def forward(self, input_ids=None, attention_mask=None): + bs = input_ids.shape[0] if input_ids is not None else 1 + return type("Output", (), {"last_hidden_state": torch.zeros(bs, 8, 16)})() + + with pytest.raises((ValueError, AttributeError, TypeError)): + PoolingTransform.apply(DummyEncoder(), "invalid_pooling_type_xyz") + + def test_sampler_transform_unsupported_model_raises_not_implemented(self): + """SamplerTransform must raise NotImplementedError for unsupported model class.""" + from QEfficient.transformers.models.pytorch_transforms import SamplerTransform + + class UnsupportedModel(nn.Module): + def forward(self, x): + return x + + with pytest.raises(NotImplementedError): + SamplerTransform.apply( + UnsupportedModel(), + qaic_config={"include_sampler": True}, + ) + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForImageTextToText error paths +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestVLMErrorPaths: + """VLM model must raise ValueError when both skip_lang and skip_vision are True.""" + + def test_skip_lang_and_skip_vision_both_true_raises_value_error(self): + """_QEffAutoModelForImageTextToTextDualQPC.compile() must raise ValueError + when both skip_lang=True and skip_vision=True.""" + from QEfficient.transformers.models.modeling_auto import _QEffAutoModelForImageTextToTextDualQPC + + # We test the compile method's validation logic directly + # by checking the ValueError is raised before any model loading + # We can test this by checking the class has the validation + assert hasattr(_QEffAutoModelForImageTextToTextDualQPC, "compile") + + def test_qeff_auto_model_for_image_text_to_text_class_exists(self): + """QEFFAutoModelForImageTextToText must be importable.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForImageTextToText + + assert QEFFAutoModelForImageTextToText is not None + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForSpeechSeq2Seq error paths +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestSpeechSeq2SeqErrorPaths: + """QEFFAutoModelForSpeechSeq2Seq must raise TypeError for non-seq2seq models.""" + + def test_non_seq2seq_model_raises_type_error(self): + """Wrapping a non-ForConditionalGeneration model must raise TypeError.""" + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForSpeechSeq2Seq + + model = make_tiny_gpt2() + with pytest.raises(TypeError, match="ForConditionalGeneration"): + QEFFAutoModelForSpeechSeq2Seq(model) + + +# --------------------------------------------------------------------------- +# Tests: is_tlm flag +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestIsTLMFlag: + """is_tlm flag must be set correctly based on qaic_config.""" + + def test_is_tlm_false_without_config(self): + """is_tlm must be False when no qaic_config is provided.""" + model = make_tiny_gpt2() + qeff = QEFFAutoModelForCausalLM(model) + assert qeff.is_tlm is False + + def test_is_tlm_false_with_empty_config(self): + """is_tlm must be False when qaic_config has no speculative_model_type.""" + model = make_tiny_gpt2() + qeff = QEFFAutoModelForCausalLM(model, qaic_config={}) + assert qeff.is_tlm is False + + def test_is_tlm_true_with_target_type(self): + """is_tlm must be True when speculative_model_type='target'.""" + model = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "target"}) + assert qeff.is_tlm is True + + def test_turbo_type_requires_pretrained_model_name(self): + """speculative_model_type='turbo' without pretrained_model_name_or_path must raise KeyError.""" + model = make_tiny_llama() + with pytest.raises(KeyError, match="pretrained_model_name_or_path"): + QEFFAutoModelForCausalLM(model, qaic_config={"speculative_model_type": "turbo"}) + + def test_cb_and_tlm_together_model_is_tlm(self): + """continuous_batching=True with TLM: model must still be recognized as TLM.""" + model = make_tiny_llama() + qeff = QEFFAutoModelForCausalLM( + model, + continuous_batching=True, + qaic_config={"speculative_model_type": "target"}, + ) + # The model should be recognized as TLM regardless of CB flag + assert qeff.is_tlm is True diff --git a/tests/unit_test/utils/test_generation.py b/tests/unit_test/utils/test_generation.py new file mode 100644 index 000000000..b85c3c4b8 --- /dev/null +++ b/tests/unit_test/utils/test_generation.py @@ -0,0 +1,1104 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +CPU-only tests for QEfficient.generation module. + +Tests verify: + - Module importability and dataclass construction + - Pure utility functions (calculate_latency, fix_prompts, etc.) + - File I/O (write_io_files, get_compilation_dims, read_prompts_txt_file) + - VisionHandler initialization and config-based methods + - QEffTextGenerationBase: prefill, decode, chunking, continuous batching, + prepare_decode_inputs, initialize_decode_inputs, update_decode_input, + generate_decode_stream via a fully mocked QAICInferenceSession + +All tests run on CPU only. QAICInferenceSession is mocked so no QAIC hardware +is required. +""" + +import json +from collections import deque +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest +from transformers import AutoTokenizer + +from QEfficient.generation.text_generation_inference import ( + CloudAI100ExecInfo, + CloudAI100ExecInfoNew, + PerfMetrics, + calculate_latency, + fix_prompt_to_lora_id_mapping, + fix_prompts, + get_compilation_dims, + get_input_prompts, + read_prompts_txt_file, + write_io_files, +) + +# --------------------------------------------------------------------------- +# Shared mock helpers +# --------------------------------------------------------------------------- + +VOCAB_SIZE = 50257 # gpt2 tokenizer eos_token_id=50256 +CTX_LEN = 32 +PREFILL_LEN = 8 +BATCH_SIZE = 1 + + +def _make_mock_session( + batch_size=BATCH_SIZE, + prefill_seq_len=PREFILL_LEN, + ctx_len=CTX_LEN, + vocab_size=VOCAB_SIZE, + full_batch_size=None, + include_sampler=False, + force_seq_len=None, +): + """ + Build a MagicMock that mimics QAICInferenceSession well enough for + QEffTextGenerationBase to initialise and run on CPU. + """ + session = MagicMock() + + # --- binding helpers --- + def _binding(name, dims, direction="input"): + b = MagicMock() + b.name = name + b.dims = dims + b.dir = "input" if direction == "input" else "output" + b.size = int(np.prod(dims)) * 4 # 4 bytes per float32 + b.type = 1 # FLOAT_TYPE + return b + + # Build bindings list + bindings = [ + _binding("input_ids", [batch_size, prefill_seq_len], "input"), + _binding("position_ids", [batch_size, prefill_seq_len], "input"), + _binding("logits", [batch_size, prefill_seq_len, vocab_size], "output"), + ] + if full_batch_size is not None: + bindings.append(_binding("batch_index", [full_batch_size, 1], "input")) + + session.bindings = bindings + session.binding_index_map = {b.name: i for i, b in enumerate(bindings)} + session.allowed_shapes = [] # use bindings dims directly + session.input_names = [b.name for b in bindings if b.dir == "input"] + session.output_names = [b.name for b in bindings if b.dir == "output"] + session.is_active = True + + # run() returns logits with argmax-able values + def _run(inputs): + bs = inputs.get("input_ids", np.zeros((batch_size, 1))).shape[0] + seq = ( + force_seq_len if force_seq_len is not None else inputs.get("input_ids", np.zeros((batch_size, 1))).shape[1] + ) + logits = np.zeros((bs, seq, vocab_size), dtype=np.float32) + logits[:, :, 42] = 1.0 # always predict token 42 + return {"logits": logits} + + session.run.side_effect = _run + session.skip_buffers = MagicMock() + session.set_buffers = MagicMock() + session.activate = MagicMock() + session.deactivate = MagicMock() + return session + + +def _make_tokenizer(): + """Return a tiny GPT2 tokenizer (downloads once, cached).""" + try: + tok = AutoTokenizer.from_pretrained("gpt2") + tok.pad_token = tok.eos_token + return tok + except Exception: + pytest.skip("Cannot load gpt2 tokenizer (network unavailable)") + + +def _make_base_instance( + batch_size=BATCH_SIZE, + ctx_len=CTX_LEN, + full_batch_size=None, +): + """ + Construct a QEffTextGenerationBase with a mocked session. + Patches QAICInferenceSession so no hardware is needed. + """ + from QEfficient.generation.text_generation_inference import QEffTextGenerationBase + + tok = _make_tokenizer() + mock_session = _make_mock_session( + batch_size=batch_size, + ctx_len=ctx_len, + full_batch_size=full_batch_size, + ) + + with patch( + "QEfficient.generation.text_generation_inference.QAICInferenceSession", + return_value=mock_session, + ): + obj = QEffTextGenerationBase( + tokenizer=tok, + qpc_path="/fake/path/model.qpc", + ctx_len=ctx_len, + full_batch_size=full_batch_size, + ) + return obj, tok, mock_session + + +# --------------------------------------------------------------------------- +# Tests: Module importability +# --------------------------------------------------------------------------- + + +class TestGenerationModuleImportability: + """All generation modules must be importable on CPU.""" + + def test_cloud_infer_importable(self): + import QEfficient.generation.cloud_infer + + assert QEfficient.generation.cloud_infer is not None + + def test_embedding_handler_importable(self): + import QEfficient.generation.embedding_handler + + assert QEfficient.generation.embedding_handler is not None + + def test_text_generation_inference_importable(self): + import QEfficient.generation.text_generation_inference + + assert QEfficient.generation.text_generation_inference is not None + + def test_vlm_generation_importable(self): + import QEfficient.generation.vlm_generation + + assert QEfficient.generation.vlm_generation is not None + + def test_vision_handler_importable(self): + from QEfficient.generation.embedding_handler import VisionHandler + + assert VisionHandler is not None + + def test_text_generation_class_importable(self): + from QEfficient.generation.text_generation_inference import TextGeneration + + assert TextGeneration is not None + + def test_qeff_text_generation_base_importable(self): + from QEfficient.generation.text_generation_inference import QEffTextGenerationBase + + assert QEffTextGenerationBase is not None + + def test_vision_language_generation_importable(self): + from QEfficient.generation.vlm_generation import VisionLanguageGeneration + + assert VisionLanguageGeneration is not None + + +# --------------------------------------------------------------------------- +# Tests: PerfMetrics dataclass +# --------------------------------------------------------------------------- + + +class TestPerfMetricsDataclass: + def test_construction_and_field_access(self): + m = PerfMetrics(prefill_time=1.5, decode_perf=50.0, total_perf=45.0, total_time=10.0) + assert m.prefill_time == 1.5 + assert m.decode_perf == 50.0 + assert m.total_perf == 45.0 + assert m.total_time == 10.0 + + def test_repr_contains_values(self): + m = PerfMetrics(1.5, 50.0, 45.0, 10.0) + r = repr(m) + assert "1.5" in r or "1.50" in r + + def test_zero_values_allowed(self): + m = PerfMetrics(0.0, 0.0, 0.0, 0.0) + assert m.prefill_time == 0.0 + + +# --------------------------------------------------------------------------- +# Tests: CloudAI100ExecInfo dataclass +# --------------------------------------------------------------------------- + + +class TestCloudAI100ExecInfoDataclass: + def test_construction_and_repr(self): + m = PerfMetrics(1.5, 50.0, 45.0, 10.0) + info = CloudAI100ExecInfo( + batch_size=1, + generated_texts=["Hello"], + generated_ids=[np.array([1, 2, 3])], + perf_metrics=m, + ) + assert info.batch_size == 1 + r = repr(info) + assert "Prefill" in r or "prefill" in r + + def test_nested_list_generated_texts(self): + m = PerfMetrics(1.5, 50.0, 45.0, 10.0) + info = CloudAI100ExecInfo( + batch_size=2, + generated_texts=[["A", "B"], ["C", "D"]], + generated_ids=[np.array([1]), np.array([2])], + perf_metrics=m, + ) + assert len(info.generated_texts) == 2 + + def test_cloud_ai100_exec_info_new(self): + m = PerfMetrics(1.5, 50.0, 45.0, 10.0) + info = CloudAI100ExecInfoNew( + batch_size=1, + generated_ids=[np.array([1, 2, 3])], + perf_metrics=m, + ) + assert info.batch_size == 1 + assert "Prefill" in repr(info) or "prefill" in repr(info) + + +# --------------------------------------------------------------------------- +# Tests: calculate_latency +# --------------------------------------------------------------------------- + + +class TestCalculateLatency: + def test_normal_case(self): + pf, dp, tp, tt = calculate_latency(100, 5.0, 1.0, 11.0, 0.0) + assert pf == pytest.approx(4.0) + assert dp == pytest.approx(100 / 6.0) + assert tp == pytest.approx(100 / 10.0) + assert tt == pytest.approx(10.0) + + def test_with_decode_pause_time(self): + pf, dp, tp, tt = calculate_latency(100, 5.0, 1.0, 11.0, 1.0) + assert pf == pytest.approx(5.0) + assert dp == pytest.approx(100 / 5.0) + + def test_zero_tokens(self): + pf, dp, tp, tt = calculate_latency(0, 5.0, 1.0, 11.0, 0.0) + assert dp == 0.0 + assert tp == 0.0 + + def test_returns_floats(self): + result = calculate_latency(100, 5.0, 1.0, 11.0, 0.0) + assert all(isinstance(v, float) for v in result) + + +# --------------------------------------------------------------------------- +# Tests: get_input_prompts +# --------------------------------------------------------------------------- + + +class TestGetInputPrompts: + def test_both_none_raises(self): + with pytest.raises(ValueError): + get_input_prompts(None, None) + + def test_string_to_list(self): + r = get_input_prompts("Hello", None) + assert r == ["Hello"] + + def test_list_unchanged(self): + r = get_input_prompts(["A", "B"], None) + assert r == ["A", "B"] + + def test_txt_file_priority(self, tmp_path): + f = tmp_path / "p.txt" + f.write_text("L1\nL2\n") + r = get_input_prompts("ignored", str(f)) + assert r == ["L1", "L2"] + + +# --------------------------------------------------------------------------- +# Tests: fix_prompts +# --------------------------------------------------------------------------- + + +class TestFixPrompts: + def test_fewer_prompts_repeated(self): + r = fix_prompts(["A", "B"], 5) + assert len(r) == 5 + assert r == ["A", "B", "A", "B", "A"] + + def test_exact_batch_unchanged(self): + r = fix_prompts(["A", "B", "C"], 3) + assert r == ["A", "B", "C"] + + def test_incomplete_batch_dropped(self): + r = fix_prompts(["A", "B", "C", "D", "E"], 2) + assert len(r) == 4 + + def test_full_batch_size_used(self): + r = fix_prompts(["A", "B"], 3, full_batch_size=8) + assert len(r) == 8 + + def test_single_prompt_repeated(self): + r = fix_prompts(["X"], 4) + assert r == ["X", "X", "X", "X"] + + +# --------------------------------------------------------------------------- +# Tests: fix_prompt_to_lora_id_mapping +# --------------------------------------------------------------------------- + + +class TestFixPromptToLoraIdMapping: + def test_fewer_repeated(self): + r = fix_prompt_to_lora_id_mapping([0, 1], 5) + assert len(r) == 5 + + def test_exact_unchanged(self): + r = fix_prompt_to_lora_id_mapping([0, 1, 2], 3) + assert r == [0, 1, 2] + + def test_full_batch_size(self): + r = fix_prompt_to_lora_id_mapping([0, 1], 3, full_batch_size=8) + assert len(r) == 8 + + +# --------------------------------------------------------------------------- +# Tests: read_prompts_txt_file +# --------------------------------------------------------------------------- + + +class TestReadPromptsTxtFile: + def test_reads_lines(self, tmp_path): + f = tmp_path / "p.txt" + f.write_text("A\nB\nC\n") + assert read_prompts_txt_file(str(f)) == ["A", "B", "C"] + + def test_strips_whitespace(self, tmp_path): + f = tmp_path / "p.txt" + f.write_text(" A \n B \n") + assert read_prompts_txt_file(str(f)) == ["A", "B"] + + def test_empty_file(self, tmp_path): + f = tmp_path / "p.txt" + f.write_text("") + assert read_prompts_txt_file(str(f)) == [] + + def test_missing_file_raises(self): + with pytest.raises(FileNotFoundError): + read_prompts_txt_file("/no/such/file.txt") + + +# --------------------------------------------------------------------------- +# Tests: write_io_files +# --------------------------------------------------------------------------- + + +class TestWriteIoFiles: + def test_creates_json_and_raw_files(self, tmp_path): + inputs = {"input_ids": np.array([[1, 2, 3]], dtype=np.int64)} + outputs = {"logits": np.array([[0.1, 0.2, 0.3]], dtype=np.float32)} + write_io_files(inputs, outputs, str(tmp_path), "sub", "io", reset=True) + assert (tmp_path / "io.json").exists() + assert (tmp_path / "sub" / "input_ids.raw").exists() + assert (tmp_path / "sub" / "logits.raw").exists() + + def test_json_structure(self, tmp_path): + inputs = {"x": np.zeros((1, 4), dtype=np.float32)} + outputs = {"y": np.zeros((1, 4), dtype=np.float32)} + write_io_files(inputs, outputs, str(tmp_path), "s", "io", reset=True) + data = json.loads((tmp_path / "io.json").read_text()) + assert "IO-files" in data + assert len(data["IO-files"]) == 1 + + def test_reset_clears_previous(self, tmp_path): + inputs = {"x": np.zeros((1,), dtype=np.float32)} + outputs = {"y": np.zeros((1,), dtype=np.float32)} + write_io_files(inputs, outputs, str(tmp_path), "s1", "io", reset=True) + write_io_files(inputs, outputs, str(tmp_path), "s2", "io", reset=False) + data = json.loads((tmp_path / "io.json").read_text()) + assert len(data["IO-files"]) == 2 + + def test_include_dims(self, tmp_path): + inputs = {"x": np.zeros((2, 4), dtype=np.float32)} + outputs = {"y": np.zeros((2, 4), dtype=np.float32)} + write_io_files(inputs, outputs, str(tmp_path), "s", "io", include_dims=True, reset=True) + data = json.loads((tmp_path / "io.json").read_text()) + has_dims = any("dims" in e for e in data["IO-files"][0]) + assert has_dims + + +# --------------------------------------------------------------------------- +# Tests: get_compilation_dims +# --------------------------------------------------------------------------- + + +class TestGetCompilationDims: + def _write_spec(self, tmp_path, spec): + qpc_dir = tmp_path / "qpc" + qpc_dir.mkdir() + (qpc_dir / "specializations.json").write_text(json.dumps(spec)) + return str(qpc_dir / "model.qpc") + + def test_basic(self, tmp_path): + path = self._write_spec(tmp_path, {"specializations": [{"batch_size": "4", "ctx_len": "128"}]}) + bs, cl, fbs = get_compilation_dims(path) + assert bs == 4 and cl == 128 and fbs is None + + def test_with_full_batch_size(self, tmp_path): + path = self._write_spec( + tmp_path, {"specializations": [{"batch_size": "4", "ctx_len": "128", "full_batch_size": "16"}]} + ) + bs, cl, fbs = get_compilation_dims(path) + assert fbs == 16 + + def test_missing_file_raises(self, tmp_path): + qpc_dir = tmp_path / "qpc" + qpc_dir.mkdir() + with pytest.raises(FileNotFoundError): + get_compilation_dims(str(qpc_dir / "model.qpc")) + + def test_returns_ints(self, tmp_path): + path = self._write_spec(tmp_path, {"specializations": [{"batch_size": "2", "ctx_len": "64"}]}) + bs, cl, fbs = get_compilation_dims(path) + assert isinstance(bs, int) and isinstance(cl, int) + + +# --------------------------------------------------------------------------- +# Tests: QEffTextGenerationBase construction (mocked session) +# --------------------------------------------------------------------------- + + +class TestQEffTextGenerationBaseConstruction: + """QEffTextGenerationBase must initialise correctly with a mocked session.""" + + def test_construction_succeeds(self): + obj, tok, _ = _make_base_instance() + assert obj is not None + + def test_batch_size_fetched(self): + obj, _, _ = _make_base_instance(batch_size=2) + assert obj.batch_size == 2 + + def test_prefill_seq_len_fetched(self): + obj, _, _ = _make_base_instance() + assert obj._prefill_seq_len == PREFILL_LEN + + def test_ctx_len_stored(self): + obj, _, _ = _make_base_instance(ctx_len=64) + assert obj._ctx_len == 64 + + def test_tokenizer_stored(self): + obj, tok, _ = _make_base_instance() + assert obj.tokenizer is tok + + def test_full_batch_size_none_by_default(self): + obj, _, _ = _make_base_instance() + assert obj.full_batch_size is None + + def test_vocab_size_fetched(self): + obj, _, _ = _make_base_instance() + assert obj._vocab_size == VOCAB_SIZE + + def test_session_skip_buffers_called(self): + obj, _, mock_session = _make_base_instance() + mock_session.skip_buffers.assert_called() + + +# --------------------------------------------------------------------------- +# Tests: initialize_decode_inputs +# --------------------------------------------------------------------------- + + +class TestInitializeDecodeInputs: + """initialize_decode_inputs must allocate correctly shaped numpy arrays.""" + + def test_generated_ids_shape(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(num_prompts=2, execution_batch_size=1, max_gen_length=20) + assert obj.generated_ids.shape == (2, 20) + + def test_decode_input_ids_shape(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(num_prompts=1, execution_batch_size=1, max_gen_length=10) + assert obj.decode_input_ids.shape == (1, 1) + + def test_decode_pos_ids_shape(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(num_prompts=1, execution_batch_size=1, max_gen_length=10) + assert obj.decode_pos_ids.shape == (1, 1) + + def test_generation_len_shape(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(num_prompts=1, execution_batch_size=1, max_gen_length=10) + assert obj.generation_len.shape == (1, 1) + + def test_generated_ids_filled_with_pad_token(self): + obj, tok, _ = _make_base_instance() + obj.initialize_decode_inputs(num_prompts=1, execution_batch_size=1, max_gen_length=10) + assert np.all(obj.generated_ids == tok.pad_token_id) + + def test_decode_input_ids_zero_initialized(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(num_prompts=1, execution_batch_size=1, max_gen_length=10) + assert np.all(obj.decode_input_ids == 0) + + +# --------------------------------------------------------------------------- +# Tests: prepare_decode_inputs +# --------------------------------------------------------------------------- + + +class TestPrepareDecodeInputs: + """prepare_decode_inputs must build correct decode input dict.""" + + def test_returns_dict_with_input_ids(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(1, 1, 10) + decode_inputs = obj.prepare_decode_inputs() + assert "input_ids" in decode_inputs + + def test_returns_dict_with_position_ids(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(1, 1, 10) + decode_inputs = obj.prepare_decode_inputs() + assert "position_ids" in decode_inputs + + def test_input_ids_shape_is_batch_by_1(self): + obj, _, _ = _make_base_instance(batch_size=2) + obj.initialize_decode_inputs(2, 2, 10) + decode_inputs = obj.prepare_decode_inputs() + assert decode_inputs["input_ids"].shape == (2, 1) + + def test_position_ids_shape_is_batch_by_1(self): + obj, _, _ = _make_base_instance(batch_size=2) + obj.initialize_decode_inputs(2, 2, 10) + decode_inputs = obj.prepare_decode_inputs() + assert decode_inputs["position_ids"].shape == (2, 1) + + def test_no_batch_index_without_full_batch_size(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(1, 1, 10) + decode_inputs = obj.prepare_decode_inputs() + assert "batch_index" not in decode_inputs + + +# --------------------------------------------------------------------------- +# Tests: update_decode_input +# --------------------------------------------------------------------------- + + +class TestUpdateDecodeInput: + """update_decode_input must correctly update decode state arrays.""" + + def _make_outputs(self, token_id=42): + logits = np.zeros((1, 1, VOCAB_SIZE), dtype=np.float32) + logits[0, 0, token_id] = 1.0 + return {"logits": logits} + + def test_decode_input_ids_updated(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(1, 1, 10) + outputs = self._make_outputs(token_id=42) + position_ids = np.array([[PREFILL_LEN]]) + obj.update_decode_input(outputs, position_ids, generation_len=10) + assert obj.decode_input_ids[0, 0] == 42 + + def test_decode_pos_ids_updated(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(1, 1, 10) + outputs = self._make_outputs(token_id=42) + position_ids = np.array([[PREFILL_LEN]]) + obj.update_decode_input(outputs, position_ids, generation_len=10) + assert obj.decode_pos_ids[0, 0] == PREFILL_LEN + + def test_generated_ids_first_token_set(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(1, 1, 10) + outputs = self._make_outputs(token_id=99) + position_ids = np.array([[PREFILL_LEN]]) + obj.update_decode_input(outputs, position_ids, generation_len=10) + assert obj.generated_ids[0, 0] == 99 + + def test_returns_next_token_id(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(1, 1, 10) + outputs = self._make_outputs(token_id=77) + position_ids = np.array([[PREFILL_LEN]]) + next_token = obj.update_decode_input(outputs, position_ids, generation_len=10) + assert next_token[0, 0] == 77 + + +# --------------------------------------------------------------------------- +# Tests: run_prefill (mocked session, chunking logic) +# --------------------------------------------------------------------------- + + +class TestRunPrefill: + """run_prefill must tokenize, chunk, and call session.run for each chunk.""" + + def test_run_prefill_returns_outputs_position_ids_generation_len(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(1, 1, CTX_LEN) + outputs, position_ids, gen_len = obj.run_prefill( + prompt=["Hello world"], + generation_len=None, + ) + assert outputs is not None + assert position_ids is not None + assert gen_len is not None + + def test_run_prefill_calls_session_run(self): + obj, _, mock_session = _make_base_instance() + obj.initialize_decode_inputs(1, 1, CTX_LEN) + obj.run_prefill(prompt=["Hello world"], generation_len=None) + assert mock_session.run.called + + def test_run_prefill_generation_len_bounded_by_ctx_len(self): + obj, _, _ = _make_base_instance(ctx_len=CTX_LEN) + obj.initialize_decode_inputs(1, 1, CTX_LEN) + _, _, gen_len = obj.run_prefill(prompt=["Hello world"], generation_len=None) + assert gen_len <= CTX_LEN + + def test_run_prefill_generation_len_positive(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(1, 1, CTX_LEN) + _, _, gen_len = obj.run_prefill(prompt=["Hello world"], generation_len=None) + assert gen_len > 0 + + def test_run_prefill_chunking_multiple_chunks(self): + """A long prompt that exceeds prefill_seq_len must be split into chunks.""" + obj, tok, mock_session = _make_base_instance() + obj.initialize_decode_inputs(1, 1, CTX_LEN) + # Create a prompt that tokenizes to > PREFILL_LEN tokens + long_prompt = " ".join(["hello"] * 20) + obj.run_prefill(prompt=[long_prompt], generation_len=None) + # session.run must be called at least once (possibly multiple times for chunks) + assert mock_session.run.call_count >= 1 + + def test_run_prefill_with_explicit_generation_len(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(1, 1, CTX_LEN) + _, _, gen_len = obj.run_prefill(prompt=["Hello"], generation_len=5) + assert gen_len == 5 + + def test_run_prefill_output_has_logits(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(1, 1, CTX_LEN) + outputs, _, _ = obj.run_prefill(prompt=["Hello world"], generation_len=None) + assert "logits" in outputs + + def test_run_prefill_position_ids_shape(self): + obj, _, _ = _make_base_instance() + obj.initialize_decode_inputs(1, 1, CTX_LEN) + _, position_ids, _ = obj.run_prefill(prompt=["Hello world"], generation_len=None) + assert position_ids.shape[0] == 1 # batch dim + + +# --------------------------------------------------------------------------- +# Tests: run_decode (mocked session) +# --------------------------------------------------------------------------- + + +class TestRunDecode: + """run_decode must iterate and update generated_ids correctly.""" + + def _setup_decode(self, generation_len=5): + obj, tok, mock_session = _make_base_instance() + obj.initialize_decode_inputs(1, 1, generation_len + 2) + # Simulate prefill output + outputs = {"logits": np.zeros((1, 1, VOCAB_SIZE), dtype=np.float32)} + outputs["logits"][0, 0, 42] = 1.0 + position_ids = np.array([[PREFILL_LEN]]) + obj.update_decode_input(outputs, position_ids, generation_len=generation_len) + decode_inputs = obj.prepare_decode_inputs() + return obj, tok, mock_session, decode_inputs, generation_len + + def test_run_decode_returns_num_tokens(self): + obj, tok, mock_session, decode_inputs, gen_len = self._setup_decode(5) + num_token = obj.run_decode(decode_inputs, gen_len, automation=True) + assert isinstance(num_token, int) + assert num_token >= 1 + + def test_run_decode_calls_session_run(self): + obj, tok, mock_session, decode_inputs, gen_len = self._setup_decode(3) + obj.run_decode(decode_inputs, gen_len, automation=True) + assert mock_session.run.called + + def test_run_decode_updates_generated_ids(self): + obj, tok, mock_session, decode_inputs, gen_len = self._setup_decode(3) + obj.run_decode(decode_inputs, gen_len, automation=True) + # generated_ids[:, 1:] should have been updated + assert obj.generated_ids[0, 1] == 42 # mock always returns token 42 + + def test_run_decode_stops_at_eos(self): + """Decode must stop early when EOS token is generated.""" + obj, tok, mock_session, decode_inputs, gen_len = self._setup_decode(10) + + # Make session return EOS token + def _run_eos(inputs): + logits = np.zeros((1, 1, VOCAB_SIZE), dtype=np.float32) + logits[0, 0, tok.eos_token_id] = 1.0 + return {"logits": logits} + + mock_session.run.side_effect = _run_eos + num_token = obj.run_decode(decode_inputs, gen_len, automation=False) + # Should stop early (<= generation_len) + assert num_token <= gen_len + + def test_run_decode_position_ids_advance(self): + """position_ids must increase by 1 each decode step.""" + obj, tok, mock_session, decode_inputs, gen_len = self._setup_decode(3) + initial_pos = decode_inputs["position_ids"][0, -1].item() + obj.run_decode(decode_inputs, gen_len, automation=True) + # After decode, position_ids should have advanced + final_pos = decode_inputs["position_ids"][0, -1].item() + assert final_pos > initial_pos + + def test_run_decode_generated_ids_are_valid_tokens(self): + obj, tok, mock_session, decode_inputs, gen_len = self._setup_decode(3) + obj.run_decode(decode_inputs, gen_len, automation=True) + for i in range(1, gen_len): + token = obj.generated_ids[0, i] + if token != tok.pad_token_id: + assert 0 <= token < VOCAB_SIZE + + +# --------------------------------------------------------------------------- +# Tests: generate_decode_stream (mocked session) +# --------------------------------------------------------------------------- + + +class TestGenerateDecodeStream: + """generate_decode_stream must yield token arrays at each step.""" + + def _setup_stream(self, generation_len=4): + obj, tok, mock_session = _make_base_instance() + obj.initialize_decode_inputs(1, 1, generation_len + 2) + outputs = {"logits": np.zeros((1, 1, VOCAB_SIZE), dtype=np.float32)} + outputs["logits"][0, 0, 42] = 1.0 + position_ids = np.array([[PREFILL_LEN]]) + obj.update_decode_input(outputs, position_ids, generation_len=generation_len) + decode_inputs = obj.prepare_decode_inputs() + return obj, tok, mock_session, decode_inputs, generation_len + + def test_yields_token_arrays(self): + obj, tok, mock_session, decode_inputs, gen_len = self._setup_stream(4) + tokens = list(obj.generate_decode_stream(decode_inputs, gen_len, automation=True)) + assert len(tokens) >= 1 + for t in tokens: + assert isinstance(t, np.ndarray) + + def test_yields_correct_shape(self): + obj, tok, mock_session, decode_inputs, gen_len = self._setup_stream(4) + tokens = list(obj.generate_decode_stream(decode_inputs, gen_len, automation=True)) + for t in tokens: + assert t.shape[0] == 1 # batch dim + + def test_yields_at_most_generation_len_tokens(self): + obj, tok, mock_session, decode_inputs, gen_len = self._setup_stream(4) + tokens = list(obj.generate_decode_stream(decode_inputs, gen_len, automation=True)) + assert len(tokens) <= gen_len + 1 # +1 for final yield + + def test_stops_at_eos(self): + obj, tok, mock_session, decode_inputs, gen_len = self._setup_stream(10) + + def _run_eos(inputs): + logits = np.zeros((1, 1, VOCAB_SIZE), dtype=np.float32) + logits[0, 0, tok.eos_token_id] = 1.0 + return {"logits": logits} + + mock_session.run.side_effect = _run_eos + tokens = list(obj.generate_decode_stream(decode_inputs, gen_len, automation=False)) + assert len(tokens) <= gen_len + 1 + + +# --------------------------------------------------------------------------- +# Tests: Chunking logic in prefill +# --------------------------------------------------------------------------- + + +class TestPrefillChunking: + """Prefill must correctly chunk long prompts into prefill_seq_len pieces.""" + + def test_single_chunk_for_short_prompt(self): + obj, _, mock_session = _make_base_instance() + obj.initialize_decode_inputs(1, 1, CTX_LEN) + mock_session.run.reset_mock() + # Short prompt: should fit in one chunk + obj.run_prefill(prompt=["Hi"], generation_len=None) + assert mock_session.run.call_count == 1 + + def test_multiple_chunks_for_long_prompt(self): + """A prompt tokenizing to > prefill_seq_len must produce multiple chunks.""" + obj, tok, mock_session = _make_base_instance() + obj.initialize_decode_inputs(1, 1, CTX_LEN) + mock_session.run.reset_mock() + # Force a prompt that tokenizes to > PREFILL_LEN tokens + # by using a very long string + long_prompt = "hello " * 30 # ~30 tokens + obj.run_prefill(prompt=[long_prompt], generation_len=None) + # With prefill_seq_len=8, 30 tokens → ceil(30/8) = 4 chunks + assert mock_session.run.call_count >= 2 + + def test_chunk_inputs_have_correct_seq_len(self): + """Each chunk passed to session.run must have seq_len == prefill_seq_len.""" + obj, _, mock_session = _make_base_instance() + obj.initialize_decode_inputs(1, 1, CTX_LEN) + mock_session.run.reset_mock() + long_prompt = "hello " * 30 + obj.run_prefill(prompt=[long_prompt], generation_len=None) + for call in mock_session.run.call_args_list: + chunk_inputs = call[0][0] + assert chunk_inputs["input_ids"].shape[1] == PREFILL_LEN + + def test_position_ids_in_chunks_are_sequential(self): + """position_ids in each chunk must be sequential (or -1 for padding).""" + obj, _, mock_session = _make_base_instance() + obj.initialize_decode_inputs(1, 1, CTX_LEN) + mock_session.run.reset_mock() + long_prompt = "hello " * 20 + obj.run_prefill(prompt=[long_prompt], generation_len=None) + for call in mock_session.run.call_args_list: + chunk_inputs = call[0][0] + pos = chunk_inputs["position_ids"][0] + valid_pos = pos[pos >= 0] + if len(valid_pos) > 1: + diffs = np.diff(valid_pos) + assert np.all(diffs == 1), f"Non-sequential position_ids: {valid_pos}" + + +# --------------------------------------------------------------------------- +# Tests: Continuous batching (mocked session with full_batch_size) +# --------------------------------------------------------------------------- + + +class TestContinuousBatching: + """run_continuous_batching_decode must handle the CB decode loop correctly.""" + + def _make_cb_instance(self, full_batch_size=2): + from QEfficient.generation.text_generation_inference import QEffTextGenerationBase + + tok = _make_tokenizer() + # For CB prefill, run_prefill expects to read next token from logits. + # We force seq_len=1 so update_decode_input can store into (full_batch_size, 1). + mock_session = _make_mock_session( + batch_size=full_batch_size, + prefill_seq_len=PREFILL_LEN, + ctx_len=CTX_LEN, + vocab_size=VOCAB_SIZE, + full_batch_size=full_batch_size, + force_seq_len=1, + ) + + # Add batch_index to binding_index_map + bi_binding = MagicMock() + bi_binding.name = "batch_index" + bi_binding.dims = [full_batch_size, 1] + bi_binding.dir = "input" + bi_binding.size = full_batch_size * 4 + bi_binding.type = 1 + mock_session.bindings.append(bi_binding) + mock_session.binding_index_map["batch_index"] = len(mock_session.bindings) - 1 + mock_session.input_names.append("batch_index") + + # allowed_shapes for full_batch_size detection + mock_session.allowed_shapes = [ + [ + (4, [full_batch_size, PREFILL_LEN]), # input_ids + (4, [full_batch_size, PREFILL_LEN]), # position_ids + (4, [full_batch_size, PREFILL_LEN, VOCAB_SIZE]), # logits + (4, [full_batch_size, 1]), # batch_index + ], + [ + (4, [full_batch_size, 1]), # input_ids decode + (4, [full_batch_size, 1]), # position_ids decode + (4, [full_batch_size, 1, VOCAB_SIZE]), # logits decode + (4, [full_batch_size, 1]), # batch_index + ], + ] + + with patch( + "QEfficient.generation.text_generation_inference.QAICInferenceSession", + return_value=mock_session, + ): + obj = QEffTextGenerationBase( + tokenizer=tok, + qpc_path="/fake/path/model.qpc", + ctx_len=CTX_LEN, + full_batch_size=full_batch_size, + ) + return obj, tok, mock_session + + def test_cb_instance_has_full_batch_size(self): + obj, _, _ = self._make_cb_instance(full_batch_size=2) + assert obj.full_batch_size == 2 + + def test_initialize_decode_inputs_with_full_batch_size(self): + obj, _, _ = self._make_cb_instance(full_batch_size=2) + obj.initialize_decode_inputs( + num_prompts=4, + execution_batch_size=2, + max_gen_length=10, + ) + assert obj.generated_ids.shape == (4, 10) + assert obj.decode_input_ids.shape == (2, 1) + + def test_prepare_decode_inputs_with_batch_index(self): + obj, _, _ = self._make_cb_instance(full_batch_size=2) + obj.initialize_decode_inputs(2, 2, 10) + obj.batch_index = np.arange(2).reshape(-1, 1) + decode_inputs = obj.prepare_decode_inputs() + assert "batch_index" in decode_inputs + + def test_run_prefill_for_all_inputs_calls_session(self): + obj, tok, mock_session = self._make_cb_instance(full_batch_size=2) + obj.initialize_decode_inputs(2, 2, CTX_LEN) + mock_session.run.reset_mock() + prompt_queue = deque(["Hello", "World"]) + obj.run_prefill_for_all_inputs(prompt_queue, generation_len=None) + assert mock_session.run.called + + def test_run_prefill_for_all_inputs_empties_queue(self): + obj, tok, mock_session = self._make_cb_instance(full_batch_size=2) + obj.initialize_decode_inputs(2, 2, CTX_LEN) + prompt_queue = deque(["Hello", "World"]) + obj.run_prefill_for_all_inputs(prompt_queue, generation_len=None) + assert len(prompt_queue) == 0 + + +# --------------------------------------------------------------------------- +# Tests: _fetch_next_token_id +# --------------------------------------------------------------------------- + + +class TestFetchNextTokenId: + """_fetch_next_token_id must extract argmax from logits correctly.""" + + def test_returns_argmax_of_logits(self): + obj, _, _ = _make_base_instance() + logits = np.zeros((1, 1, VOCAB_SIZE), dtype=np.float32) + logits[0, 0, 77] = 1.0 + outputs = {"logits": logits} + token = obj._fetch_next_token_id(outputs) + assert token[0, 0] == 77 + + def test_batch_argmax(self): + obj, _, _ = _make_base_instance(batch_size=2) + logits = np.zeros((2, 1, VOCAB_SIZE), dtype=np.float32) + logits[0, 0, 10] = 1.0 + logits[1, 0, 20] = 1.0 + outputs = {"logits": logits} + tokens = obj._fetch_next_token_id(outputs) + assert tokens[0, 0] == 10 + assert tokens[1, 0] == 20 + + def test_2d_logits_expanded(self): + """2D logits (batch, vocab) must be expanded to (batch, 1, vocab).""" + obj, _, _ = _make_base_instance() + logits = np.zeros((1, VOCAB_SIZE), dtype=np.float32) + logits[0, 55] = 1.0 + outputs = {"logits": logits} + token = obj._fetch_next_token_id(outputs) + assert token[0, 0] == 55 + + +# --------------------------------------------------------------------------- +# Tests: _set_output_buffers +# --------------------------------------------------------------------------- + + +class TestSetOutputBuffers: + """_set_output_buffers must call session.set_buffers with correct shapes.""" + + def test_set_output_buffers_calls_set_buffers(self): + obj, _, mock_session = _make_base_instance() + mock_session.set_buffers.reset_mock() + obj._set_output_buffers(batch_size=1, sequence_length=1) + mock_session.set_buffers.assert_called_once() + + def test_set_output_buffers_logits_shape(self): + obj, _, mock_session = _make_base_instance() + mock_session.set_buffers.reset_mock() + obj._set_output_buffers(batch_size=2, sequence_length=4) + call_args = mock_session.set_buffers.call_args[0][0] + assert "logits" in call_args + assert call_args["logits"].shape == (2, 4, VOCAB_SIZE) + + def test_set_output_buffers_dtype_float32(self): + obj, _, mock_session = _make_base_instance() + mock_session.set_buffers.reset_mock() + obj._set_output_buffers(batch_size=1, sequence_length=1) + call_args = mock_session.set_buffers.call_args[0][0] + assert call_args["logits"].dtype == np.float32 + + +# --------------------------------------------------------------------------- +# Tests: VisionHandler initialization (CPU-only) +# --------------------------------------------------------------------------- + + +class TestVisionHandlerInit: + """VisionHandler must initialize correctly with None sessions.""" + + def test_construction_with_none_sessions(self): + from QEfficient.generation.embedding_handler import VisionHandler + + h = VisionHandler(qeff_model=None, vision_session=None, processor=None, tokenizer=None) + assert h is not None + + def test_is_available_false_with_none(self): + from QEfficient.generation.embedding_handler import VisionHandler + + h = VisionHandler(qeff_model=None, vision_session=None, processor=None, tokenizer=None) + assert h.is_available() is False + + def test_is_available_false_session_no_processor(self): + from QEfficient.generation.embedding_handler import VisionHandler + + h = VisionHandler(qeff_model=None, vision_session=MagicMock(), processor=None, tokenizer=None) + assert h.is_available() is False + + def test_get_vision_output_shapes_default(self): + from QEfficient.generation.embedding_handler import VisionHandler + + h = VisionHandler(qeff_model=None, vision_session=None, processor=None, tokenizer=None) + shapes = h.get_vision_output_shapes() + assert isinstance(shapes, dict) + assert "vision_embeds" in shapes + + def test_get_vision_output_shapes_from_config(self): + from QEfficient.generation.embedding_handler import VisionHandler + + config = {"vision_output_shapes": {"my_out": (100, 200)}} + h = VisionHandler(qeff_model=None, vision_session=None, processor=None, tokenizer=None, config=config) + shapes = h.get_vision_output_shapes() + assert shapes["my_out"] == (100, 200) + + def test_image_dims_stored(self): + from QEfficient.generation.embedding_handler import VisionHandler + + h = VisionHandler( + qeff_model=None, vision_session=None, processor=None, tokenizer=None, image_height=224, image_width=224 + ) + assert h._image_height == 224 and h._image_width == 224 + + def test_setup_vision_buffers_raises_without_session(self): + from QEfficient.generation.embedding_handler import VisionHandler + + h = VisionHandler(qeff_model=None, vision_session=None, processor=None, tokenizer=None) + with pytest.raises(ValueError): + h.setup_vision_buffers() + + def test_run_vision_inference_raises_without_session(self): + from QEfficient.generation.embedding_handler import VisionHandler + + h = VisionHandler(qeff_model=None, vision_session=None, processor=None, tokenizer=None) + with pytest.raises(ValueError): + h.run_vision_inference({}) + + def test_prepare_vlm_inputs_raises_without_processor(self): + from QEfficient.generation.embedding_handler import VisionHandler + + h = VisionHandler(qeff_model=None, vision_session=None, processor=None, tokenizer=None) + with pytest.raises((ValueError, AttributeError)): + h.prepare_vlm_inputs("image.jpg", "query", 128) diff --git a/tests/unit_test/utils/test_input_handler.py b/tests/unit_test/utils/test_input_handler.py new file mode 100644 index 000000000..ef964529b --- /dev/null +++ b/tests/unit_test/utils/test_input_handler.py @@ -0,0 +1,409 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Tests for InputHandler: prepare_pytorch_inputs, update_pytorch_inputs, +prepare_ort_inputs, update_ort_inputs, update_ort_outputs. + +All tests run on CPU only. Tests that require a tokenizer download are +automatically skipped if the network is unavailable. +""" + +import numpy as np +import pytest +import torch +from transformers import GPT2Config, GPT2LMHeadModel + +from QEfficient.utils.generate_inputs import InputHandler + +CTX_LEN = 32 +VOCAB_SIZE = 500 + + +def _get_tokenizer(): + try: + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained("gpt2") + tok.pad_token = tok.eos_token + return tok + except Exception: + pytest.skip("Cannot load gpt2 tokenizer (network unavailable)") + + +def _make_tiny_gpt2_config(tokenizer): + return GPT2Config( + n_layer=2, + n_head=2, + n_embd=64, + vocab_size=tokenizer.vocab_size, + n_positions=CTX_LEN, + n_ctx=CTX_LEN, + ) + + +def _make_handler(tokenizer, config, prompt=None, prompt_len=8, ctx_len=CTX_LEN): + if prompt is None: + prompt = ["Hello world"] + return InputHandler( + batch_size=1, + tokenizer=tokenizer, + config=config, + prompt=prompt, + prompt_len=prompt_len, + ctx_len=ctx_len, + full_batch_size=None, + ) + + +class TestInputHandlerConstruction: + def test_construction_succeeds(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler = _make_handler(tok, cfg) + assert handler is not None + + def test_construction_with_multiple_prompts(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler = InputHandler( + batch_size=2, + tokenizer=tok, + config=cfg, + prompt=["Hello world", "The capital of France"], + prompt_len=8, + ctx_len=CTX_LEN, + full_batch_size=None, + ) + assert handler is not None + + def test_construction_with_longer_ctx_len(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler = _make_handler(tok, cfg, ctx_len=64) + assert handler is not None + + +class TestPreparePytorchInputs: + def test_returns_dict(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + inputs = _make_handler(tok, cfg).prepare_pytorch_inputs() + assert hasattr(inputs, "__getitem__") and hasattr(inputs, "keys") + + def test_has_input_ids(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + inputs = _make_handler(tok, cfg, prompt_len=8).prepare_pytorch_inputs() + assert "input_ids" in inputs + + def test_has_position_ids(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + inputs = _make_handler(tok, cfg, prompt_len=8).prepare_pytorch_inputs() + assert "position_ids" in inputs + + def test_has_past_key_values(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + inputs = _make_handler(tok, cfg, prompt_len=8).prepare_pytorch_inputs() + assert "past_key_values" in inputs + + def test_input_ids_shape(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + prompt_len = 8 + inputs = _make_handler(tok, cfg, prompt_len=prompt_len).prepare_pytorch_inputs() + assert inputs["input_ids"].shape[0] == 1 + assert inputs["input_ids"].shape[1] == prompt_len + + def test_position_ids_shape(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + prompt_len = 8 + inputs = _make_handler(tok, cfg, prompt_len=prompt_len).prepare_pytorch_inputs() + assert inputs["position_ids"].shape == (1, prompt_len) + + def test_position_ids_are_sequential(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + inputs = _make_handler(tok, cfg, prompt_len=8).prepare_pytorch_inputs() + pos = inputs["position_ids"].squeeze() + valid_pos = pos[pos >= 0] + assert len(valid_pos) > 0 + if len(valid_pos) > 1: + diffs = valid_pos[1:] - valid_pos[:-1] + assert (diffs > 0).all(), f"Position IDs are not strictly increasing: {valid_pos}" + + def test_past_key_values_has_correct_number_of_layers(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + inputs = _make_handler(tok, cfg).prepare_pytorch_inputs() + assert len(inputs["past_key_values"]) == cfg.n_layer + + def test_past_key_values_are_zero_initialized(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + inputs = _make_handler(tok, cfg).prepare_pytorch_inputs() + for layer_idx, (k, v) in enumerate(inputs["past_key_values"]): + assert torch.all(k == 0), f"Layer {layer_idx} key cache is not zero-initialized" + assert torch.all(v == 0), f"Layer {layer_idx} value cache is not zero-initialized" + + def test_past_key_values_ctx_len_dimension(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + inputs = _make_handler(tok, cfg, ctx_len=CTX_LEN).prepare_pytorch_inputs() + for layer_idx, (k, v) in enumerate(inputs["past_key_values"]): + assert k.shape[2] == CTX_LEN, f"Layer {layer_idx} key cache ctx_len={k.shape[2]}, expected {CTX_LEN}" + assert v.shape[2] == CTX_LEN, f"Layer {layer_idx} value cache ctx_len={v.shape[2]}, expected {CTX_LEN}" + + def test_input_ids_are_valid_token_ids(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + inputs = _make_handler(tok, cfg).prepare_pytorch_inputs() + ids = inputs["input_ids"] + assert (ids >= 0).all(), "Negative token IDs found" + assert (ids < tok.vocab_size).all(), "Token IDs exceed vocab_size" + + +class TestUpdatePytorchInputs: + def _run_prefill(self, tok, cfg, prompt_len=8): + from QEfficient.transformers.models.modeling_auto import QEFFAutoModelForCausalLM + + model = GPT2LMHeadModel(cfg).eval() + qeff_model = QEFFAutoModelForCausalLM(model) + handler = _make_handler(tok, cfg, prompt_len=prompt_len) + inputs = handler.prepare_pytorch_inputs() + with torch.no_grad(): + outputs = qeff_model.model(**inputs) + return handler, inputs, outputs + + def test_update_returns_dict(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler, inputs, outputs = self._run_prefill(tok, cfg) + updated = handler.update_pytorch_inputs(inputs, outputs) + assert hasattr(updated, "__getitem__") and hasattr(updated, "keys") + + def test_update_has_input_ids(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler, inputs, outputs = self._run_prefill(tok, cfg) + updated = handler.update_pytorch_inputs(inputs, outputs) + assert "input_ids" in updated + + def test_update_has_position_ids(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler, inputs, outputs = self._run_prefill(tok, cfg) + updated = handler.update_pytorch_inputs(inputs, outputs) + assert "position_ids" in updated + + def test_update_input_ids_is_single_token(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler, inputs, outputs = self._run_prefill(tok, cfg) + updated = handler.update_pytorch_inputs(inputs, outputs) + assert updated["input_ids"].shape == (1, 1), ( + f"Decode input_ids must be shape (1,1), got {updated['input_ids'].shape}" + ) + + def test_update_position_ids_advances(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + prompt_len = 8 + handler, inputs, outputs = self._run_prefill(tok, cfg, prompt_len=prompt_len) + updated = handler.update_pytorch_inputs(inputs, outputs) + decode_pos = updated["position_ids"].item() + prefill_last_valid = inputs["position_ids"][inputs["position_ids"] >= 0].max().item() + assert decode_pos > prefill_last_valid, ( + f"Decode position {decode_pos} must be > last prefill position {prefill_last_valid}" + ) + + def test_update_next_token_is_valid(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler, inputs, outputs = self._run_prefill(tok, cfg) + updated = handler.update_pytorch_inputs(inputs, outputs) + next_token = updated["input_ids"].item() + assert 0 <= next_token < tok.vocab_size, ( + f"Next token {next_token} is not a valid token ID (vocab_size={tok.vocab_size})" + ) + + +class TestPrepareOrtInputs: + def test_returns_dict_like(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + ort_inputs = _make_handler(tok, cfg).prepare_ort_inputs() + assert hasattr(ort_inputs, "__getitem__") and hasattr(ort_inputs, "keys") + + def test_has_input_ids(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + ort_inputs = dict(_make_handler(tok, cfg).prepare_ort_inputs()) + assert "input_ids" in ort_inputs + + def test_has_position_ids(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + ort_inputs = dict(_make_handler(tok, cfg).prepare_ort_inputs()) + assert "position_ids" in ort_inputs + + def test_has_past_key_value_inputs(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + ort_inputs = dict(_make_handler(tok, cfg).prepare_ort_inputs()) + has_past = any("past_key" in k or "past_value" in k for k in ort_inputs.keys()) + assert has_past, f"No past_key/past_value inputs found: {list(ort_inputs.keys())}" + + def test_input_ids_are_numpy_int64(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + ort_inputs = dict(_make_handler(tok, cfg).prepare_ort_inputs()) + ids = ort_inputs["input_ids"] + assert isinstance(ids, np.ndarray), f"input_ids must be numpy array, got {type(ids)}" + assert ids.dtype == np.int64, f"input_ids must be int64, got {ids.dtype}" + + def test_position_ids_are_numpy_int64(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + ort_inputs = dict(_make_handler(tok, cfg).prepare_ort_inputs()) + pos = ort_inputs["position_ids"] + assert isinstance(pos, np.ndarray) + assert pos.dtype == np.int64 + + def test_past_key_values_are_numpy_float32(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + ort_inputs = dict(_make_handler(tok, cfg).prepare_ort_inputs()) + for key, val in ort_inputs.items(): + if "past_key" in key or "past_value" in key: + assert isinstance(val, np.ndarray) + assert val.dtype == np.float32, f"{key} must be float32, got {val.dtype}" + + def test_past_key_values_are_zero_initialized(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + ort_inputs = dict(_make_handler(tok, cfg).prepare_ort_inputs()) + for key, val in ort_inputs.items(): + if "past_key" in key or "past_value" in key: + assert np.all(val == 0), f"{key} must be zero-initialized for prefill" + + def test_past_key_values_ctx_len_dimension(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + ort_inputs = dict(_make_handler(tok, cfg, ctx_len=CTX_LEN).prepare_ort_inputs()) + for key, val in ort_inputs.items(): + if "past_key" in key or "past_value" in key: + assert val.shape[2] == CTX_LEN, f"{key} ctx_len={val.shape[2]}, expected {CTX_LEN}" + + def test_correct_number_of_kv_cache_inputs(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + ort_inputs = dict(_make_handler(tok, cfg).prepare_ort_inputs()) + past_keys = [k for k in ort_inputs if "past_key" in k] + past_values = [k for k in ort_inputs if "past_value" in k] + assert len(past_keys) == cfg.n_layer + assert len(past_values) == cfg.n_layer + + def test_pytorch_and_ort_inputs_have_same_keys(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler = _make_handler(tok, cfg) + pt_inputs = handler.prepare_pytorch_inputs() + ort_inputs = dict(handler.prepare_ort_inputs()) + assert "input_ids" in pt_inputs and "input_ids" in ort_inputs + assert "position_ids" in pt_inputs and "position_ids" in ort_inputs + + +class TestUpdateOrtInputsOutputs: + def _make_fake_ort_outputs(self, cfg, prompt_len=8): + n_layers = cfg.n_layer + n_heads = cfg.n_head + head_dim = cfg.n_embd // n_heads + outputs = { + "logits": np.random.randn(1, prompt_len, cfg.vocab_size).astype(np.float32), + } + for i in range(n_layers): + outputs[f"past_key.{i}_RetainedState"] = np.zeros((1, n_heads, CTX_LEN, head_dim), dtype=np.float32) + outputs[f"past_value.{i}_RetainedState"] = np.zeros((1, n_heads, CTX_LEN, head_dim), dtype=np.float32) + return outputs + + def test_update_ort_outputs_returns_dict(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler = _make_handler(tok, cfg) + result = handler.update_ort_outputs(self._make_fake_ort_outputs(cfg)) + assert hasattr(result, "__getitem__") and hasattr(result, "keys") + + def test_update_ort_outputs_has_logits(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler = _make_handler(tok, cfg) + result = handler.update_ort_outputs(self._make_fake_ort_outputs(cfg)) + assert "logits" in result + + def test_update_ort_inputs_returns_dict(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler = _make_handler(tok, cfg) + ort_inputs = dict(handler.prepare_ort_inputs()) + processed = handler.update_ort_outputs(self._make_fake_ort_outputs(cfg)) + updated = handler.update_ort_inputs(ort_inputs, processed) + assert hasattr(updated, "__getitem__") and hasattr(updated, "keys") + + def test_update_ort_inputs_has_input_ids(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler = _make_handler(tok, cfg) + ort_inputs = dict(handler.prepare_ort_inputs()) + processed = handler.update_ort_outputs(self._make_fake_ort_outputs(cfg)) + updated = handler.update_ort_inputs(ort_inputs, processed) + assert "input_ids" in updated + + def test_update_ort_inputs_has_position_ids(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler = _make_handler(tok, cfg) + ort_inputs = dict(handler.prepare_ort_inputs()) + processed = handler.update_ort_outputs(self._make_fake_ort_outputs(cfg)) + updated = handler.update_ort_inputs(ort_inputs, processed) + assert "position_ids" in updated + + def test_update_ort_inputs_input_ids_batch_size_is_1(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler = _make_handler(tok, cfg, prompt_len=8) + ort_inputs = dict(handler.prepare_ort_inputs()) + processed = handler.update_ort_outputs(self._make_fake_ort_outputs(cfg, prompt_len=8)) + updated = handler.update_ort_inputs(ort_inputs, processed) + assert updated["input_ids"].shape[0] == 1 + assert isinstance(updated["input_ids"], np.ndarray) + + def test_update_ort_inputs_position_ids_advances(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + prompt_len = 8 + handler = _make_handler(tok, cfg, prompt_len=prompt_len) + ort_inputs = dict(handler.prepare_ort_inputs()) + processed = handler.update_ort_outputs(self._make_fake_ort_outputs(cfg, prompt_len=prompt_len)) + updated = handler.update_ort_inputs(ort_inputs, processed) + decode_pos = updated["position_ids"].flatten()[0] + prefill_last_valid = ort_inputs["position_ids"][ort_inputs["position_ids"] >= 0].max() + assert decode_pos > prefill_last_valid, ( + f"Decode position {decode_pos} must be > last prefill position {prefill_last_valid}" + ) + + def test_update_ort_inputs_are_numpy_arrays(self): + tok = _get_tokenizer() + cfg = _make_tiny_gpt2_config(tok) + handler = _make_handler(tok, cfg) + ort_inputs = dict(handler.prepare_ort_inputs()) + processed = handler.update_ort_outputs(self._make_fake_ort_outputs(cfg)) + updated = handler.update_ort_inputs(ort_inputs, processed) + for key, val in updated.items(): + assert isinstance(val, np.ndarray), f"ORT input '{key}' must be numpy array, got {type(val)}" diff --git a/tests/unit_test/utils/test_modeling_registry.py b/tests/unit_test/utils/test_modeling_registry.py new file mode 100644 index 000000000..0c432b4ae --- /dev/null +++ b/tests/unit_test/utils/test_modeling_registry.py @@ -0,0 +1,722 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Tests for modeling utilities, supported architectures, and model registry. + +Improvements over unit_v2: + - Expanded architecture coverage: Phi3, Gemma, Gemma2, Falcon, Mixtral, Qwen3 + - Expanded MODEL_CLASS_MAPPING coverage + - Tests for DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH + - Tests for _create_causal_mask numerical correctness + - Tests for build_model_class_mapping + - Tests for QEFFAutoModelForCausalLM class structure including continuous_batching + +All tests run on CPU only, no model loading required. +""" + +import pytest +import torch + +from QEfficient.transformers.modeling_utils import ( + DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH, + MODEL_CLASS_MAPPING, + TransformersToQEffModulesDict, + _create_causal_mask, + build_model_class_mapping, + qeff_supported_architectures, +) +from QEfficient.transformers.models.modeling_auto import ( + QEFFAutoModel, + QEFFAutoModelForCausalLM, + QEFFAutoModelForSequenceClassification, + QEFFAutoModelForSpeechSeq2Seq, +) + +# --------------------------------------------------------------------------- +# Tests: qeff_supported_architectures +# --------------------------------------------------------------------------- + + +class TestQEffSupportedArchitectures: + """qeff_supported_architectures must contain all expected model names.""" + + def test_is_not_empty(self): + assert len(qeff_supported_architectures.architectures) > 0 + + def test_contains_gpt2(self): + assert "GPT2LMHeadModel" in qeff_supported_architectures.architectures + + def test_contains_llama(self): + assert "LlamaForCausalLM" in qeff_supported_architectures.architectures + + def test_contains_mistral(self): + assert "MistralForCausalLM" in qeff_supported_architectures.architectures + + def test_contains_mixtral(self): + assert "MixtralForCausalLM" in qeff_supported_architectures.architectures + + def test_contains_phi3(self): + assert "Phi3ForCausalLM" in qeff_supported_architectures.architectures + + def test_contains_falcon(self): + assert "FalconForCausalLM" in qeff_supported_architectures.architectures + + def test_contains_qwen2(self): + assert "Qwen2ForCausalLM" in qeff_supported_architectures.architectures + + def test_contains_gemma(self): + assert "GemmaForCausalLM" in qeff_supported_architectures.architectures + + def test_contains_gemma2(self): + assert "Gemma2ForCausalLM" in qeff_supported_architectures.architectures + + def test_contains_whisper(self): + assert "WhisperForConditionalGeneration" in qeff_supported_architectures.architectures + + def test_contains_mllama(self): + assert "MllamaForCausalLM" in qeff_supported_architectures.architectures + + def test_contains_starcoder2(self): + assert "Starcoder2ForCausalLM" in qeff_supported_architectures.architectures + + def test_contains_gptj(self): + assert "GPTJForCausalLM" in qeff_supported_architectures.architectures + + def test_all_entries_are_strings(self): + for arch in qeff_supported_architectures.architectures: + assert isinstance(arch, str), f"Expected string, got {type(arch)}: {arch}" + + def test_no_duplicates(self): + archs = qeff_supported_architectures.architectures + assert len(archs) == len(set(archs)), "Duplicate entries in supported architectures" + + +# --------------------------------------------------------------------------- +# Tests: TransformersToQEffModulesDict +# --------------------------------------------------------------------------- + + +class TestTransformersToQEffModulesDict: + """TransformersToQEffModulesDict must map HF classes to QEff classes correctly.""" + + def test_is_not_empty(self): + assert len(TransformersToQEffModulesDict) > 0 + + def test_gpt2_maps_to_qeff_gpt2(self): + from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel + + from QEfficient.transformers.models.gpt2.modeling_gpt2 import QEffGPT2LMHeadModel + + assert GPT2LMHeadModel in TransformersToQEffModulesDict + assert TransformersToQEffModulesDict[GPT2LMHeadModel] is QEffGPT2LMHeadModel + + def test_llama_maps_to_qeff_llama(self): + from transformers.models.llama.modeling_llama import LlamaForCausalLM + + from QEfficient.transformers.models.llama.modeling_llama import QEffLlamaForCausalLM + + assert LlamaForCausalLM in TransformersToQEffModulesDict + assert TransformersToQEffModulesDict[LlamaForCausalLM] is QEffLlamaForCausalLM + + def test_mistral_maps_to_qeff_mistral(self): + from transformers.models.mistral.modeling_mistral import MistralForCausalLM + + assert MistralForCausalLM in TransformersToQEffModulesDict + + def test_mixtral_maps_to_qeff_mixtral(self): + from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM + + assert MixtralForCausalLM in TransformersToQEffModulesDict + + def test_qwen2_maps_to_qeff_qwen2(self): + from transformers.models.qwen2.modeling_qwen2 import Qwen2ForCausalLM + + assert Qwen2ForCausalLM in TransformersToQEffModulesDict + + def test_gemma_maps_to_qeff_gemma(self): + from transformers.models.gemma.modeling_gemma import GemmaForCausalLM + + assert GemmaForCausalLM in TransformersToQEffModulesDict + + def test_gemma2_maps_to_qeff_gemma2(self): + from transformers.models.gemma2.modeling_gemma2 import Gemma2ForCausalLM + + assert Gemma2ForCausalLM in TransformersToQEffModulesDict + + def test_falcon_maps_to_qeff_falcon(self): + from transformers.models.falcon.modeling_falcon import FalconForCausalLM + + assert FalconForCausalLM in TransformersToQEffModulesDict + + def test_phi3_maps_to_qeff_phi3(self): + from transformers.models.phi3.modeling_phi3 import Phi3ForCausalLM + + assert Phi3ForCausalLM in TransformersToQEffModulesDict + + def test_whisper_maps_to_qeff_whisper(self): + from transformers.models.whisper.modeling_whisper import WhisperForConditionalGeneration + + assert WhisperForConditionalGeneration in TransformersToQEffModulesDict + + def test_all_values_are_different_from_keys(self): + """QEff classes must be different from original HF classes.""" + for hf_cls, qeff_cls in TransformersToQEffModulesDict.items(): + assert hf_cls is not qeff_cls, f"{hf_cls} maps to itself - must map to a different QEff class" + + def test_all_values_are_classes(self): + for hf_cls, qeff_cls in TransformersToQEffModulesDict.items(): + assert isinstance(qeff_cls, type), f"Expected class, got {type(qeff_cls)} for key {hf_cls}" + + +# --------------------------------------------------------------------------- +# Tests: MODEL_CLASS_MAPPING +# --------------------------------------------------------------------------- + + +class TestModelClassMapping: + """MODEL_CLASS_MAPPING must map config class names to QEff class names.""" + + def test_is_not_empty(self): + assert len(MODEL_CLASS_MAPPING) > 0 + + def test_llama_config_maps_to_qeff_causal_lm(self): + assert MODEL_CLASS_MAPPING.get("LlamaConfig") == "QEFFAutoModelForCausalLM" + + def test_gpt2_config_maps_to_qeff_causal_lm(self): + assert MODEL_CLASS_MAPPING.get("GPT2Config") == "QEFFAutoModelForCausalLM" + + def test_mistral_config_maps_to_qeff_causal_lm(self): + assert MODEL_CLASS_MAPPING.get("MistralConfig") == "QEFFAutoModelForCausalLM" + + def test_qwen2_config_maps_to_qeff_causal_lm(self): + assert MODEL_CLASS_MAPPING.get("Qwen2Config") == "QEFFAutoModelForCausalLM" + + def test_phi3_config_maps_to_qeff_causal_lm(self): + assert MODEL_CLASS_MAPPING.get("Phi3Config") == "QEFFAutoModelForCausalLM" + + def test_gemma_config_maps_to_qeff_causal_lm(self): + assert MODEL_CLASS_MAPPING.get("GemmaConfig") == "QEFFAutoModelForCausalLM" + + def test_falcon_config_maps_to_qeff_causal_lm(self): + assert MODEL_CLASS_MAPPING.get("FalconConfig") == "QEFFAutoModelForCausalLM" + + def test_all_values_are_qeff_class_name_strings(self): + for key, value in MODEL_CLASS_MAPPING.items(): + assert isinstance(value, str), f"Expected string value, got {type(value)}" + assert "QEFF" in value or "QEff" in value, f"Expected QEff class name, got: {value}" + + def test_all_keys_are_config_class_name_strings(self): + for key in MODEL_CLASS_MAPPING.keys(): + assert isinstance(key, str), f"Expected string key, got {type(key)}" + assert "Config" in key, f"Expected config class name, got: {key}" + + +# --------------------------------------------------------------------------- +# Tests: EXTERNAL_MODEL_CLASS_MAPPING +# --------------------------------------------------------------------------- + + +class TestExternalModelClassMapping: + """EXTERNAL_MODEL_CLASS_MAPPING must contain external model entries.""" + + def test_external_mapping_exists_and_is_dict(self): + from QEfficient.transformers.modeling_utils import EXTERNAL_MODEL_CLASS_MAPPING + + assert isinstance(EXTERNAL_MODEL_CLASS_MAPPING, dict) + + def test_contains_grok1(self): + from QEfficient.transformers.modeling_utils import EXTERNAL_MODEL_CLASS_MAPPING + + assert "Grok1Config" in EXTERNAL_MODEL_CLASS_MAPPING + + +# --------------------------------------------------------------------------- +# Tests: DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH +# --------------------------------------------------------------------------- + + +class TestDynamicSeqLenSupportedModelArch: + """DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH must contain expected model types.""" + + def test_is_not_empty(self): + assert len(DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH) > 0 + + def test_contains_gemma3(self): + assert "gemma3" in DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH + + def test_contains_llama4(self): + assert "llama4" in DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH + + def test_supports_membership_test(self): + assert hasattr(DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH, "__contains__") + + def test_all_entries_are_strings(self): + for arch in DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH: + assert isinstance(arch, str) + + +# --------------------------------------------------------------------------- +# Tests: _create_causal_mask +# --------------------------------------------------------------------------- + + +class TestCreateCausalMask: + """_create_causal_mask must produce correct boolean masks.""" + + def test_shape_is_correct(self): + batch, seq, target_len = 1, 4, 8 + position_ids = torch.arange(seq).unsqueeze(0) + mask = _create_causal_mask(position_ids, target_length=target_len) + assert mask.shape == (batch, 1, seq, target_len) + + def test_dtype_is_bool(self): + position_ids = torch.arange(4).unsqueeze(0) + mask = _create_causal_mask(position_ids, target_length=8) + assert mask.dtype == torch.bool + + def test_future_positions_are_masked(self): + """mask[i, j] must be True when j > i (future token = masked).""" + seq = 4 + position_ids = torch.arange(seq).unsqueeze(0) + mask = _create_causal_mask(position_ids, target_length=seq) + for i in range(seq): + for j in range(seq): + if j > i: + assert mask[0, 0, i, j].item() is True, f"Expected mask[{i},{j}]=True (future), got False" + + def test_past_positions_are_not_masked(self): + """mask[i, j] must be False when j <= i (past/current token = not masked).""" + seq = 4 + position_ids = torch.arange(seq).unsqueeze(0) + mask = _create_causal_mask(position_ids, target_length=seq) + for i in range(seq): + for j in range(i + 1): + assert mask[0, 0, i, j].item() is False, f"Expected mask[{i},{j}]=False (past), got True" + + def test_batch_size_2_works(self): + batch, seq, target_len = 2, 4, 8 + position_ids = torch.arange(seq).unsqueeze(0).expand(batch, -1) + mask = _create_causal_mask(position_ids, target_length=target_len) + assert mask.shape[0] == batch + + def test_decode_step_shape(self): + """Single-token decode step must produce correct shape.""" + batch, target_len = 1, 16 + position_ids = torch.tensor([[8]]) + mask = _create_causal_mask(position_ids, target_length=target_len) + assert mask.shape == (batch, 1, 1, target_len) + + def test_decode_step_masks_future_positions(self): + """In decode step at position 8, positions 9..15 must be masked.""" + target_len = 16 + decode_pos = 8 + position_ids = torch.tensor([[decode_pos]]) + mask = _create_causal_mask(position_ids, target_length=target_len) + # Positions 0..decode_pos must be unmasked (False) + for j in range(decode_pos + 1): + assert mask[0, 0, 0, j].item() is False, f"Position {j} should be unmasked at decode_pos={decode_pos}" + # Positions decode_pos+1..target_len-1 must be masked (True) + for j in range(decode_pos + 1, target_len): + assert mask[0, 0, 0, j].item() is True, f"Position {j} should be masked at decode_pos={decode_pos}" + + def test_sliding_window_shape_correct(self): + batch, seq, target_len = 1, 4, 8 + position_ids = torch.arange(seq).unsqueeze(0) + mask = _create_causal_mask(position_ids, target_length=target_len, sliding_window=2) + assert mask.shape == (batch, 1, seq, target_len) + + def test_no_sliding_window_none_works(self): + position_ids = torch.arange(4).unsqueeze(0) + mask = _create_causal_mask(position_ids, target_length=8, sliding_window=None) + assert mask is not None + assert mask.shape[-1] == 8 + + def test_causal_mask_is_lower_triangular(self): + """For a square mask (seq == target_len), the unmasked region must be lower triangular.""" + seq = 6 + position_ids = torch.arange(seq).unsqueeze(0) + mask = _create_causal_mask(position_ids, target_length=seq) + # mask[i, j] == False means "attend to j from position i" + # This should be lower triangular: attend to j <= i + for i in range(seq): + for j in range(seq): + expected_masked = j > i + actual_masked = mask[0, 0, i, j].item() + assert actual_masked == expected_masked, ( + f"mask[{i},{j}]: expected {expected_masked}, got {actual_masked}" + ) + + +# --------------------------------------------------------------------------- +# Tests: build_model_class_mapping +# --------------------------------------------------------------------------- + + +class TestBuildModelClassMapping: + """build_model_class_mapping must return correct config → class name mapping.""" + + def test_returns_non_empty_dict(self): + import transformers.models.auto.modeling_auto as mapping + + result = build_model_class_mapping(mapping.AutoModelForCausalLM, "QEFFAutoModelForCausalLM") + assert isinstance(result, dict) + assert len(result) > 0 + + def test_all_values_are_the_provided_class_name(self): + import transformers.models.auto.modeling_auto as mapping + + class_name = "QEFFAutoModelForCausalLM" + result = build_model_class_mapping(mapping.AutoModelForCausalLM, class_name) + for key, value in result.items(): + assert value == class_name + + def test_all_keys_are_config_class_name_strings(self): + import transformers.models.auto.modeling_auto as mapping + + result = build_model_class_mapping(mapping.AutoModelForCausalLM, "QEFFAutoModelForCausalLM") + for key in result.keys(): + assert isinstance(key, str) + assert "Config" in key + + def test_contains_llama_config(self): + import transformers.models.auto.modeling_auto as mapping + + result = build_model_class_mapping(mapping.AutoModelForCausalLM, "QEFFAutoModelForCausalLM") + assert "LlamaConfig" in result + + def test_contains_gpt2_config(self): + import transformers.models.auto.modeling_auto as mapping + + result = build_model_class_mapping(mapping.AutoModelForCausalLM, "QEFFAutoModelForCausalLM") + assert "GPT2Config" in result + + def test_contains_mistral_config(self): + import transformers.models.auto.modeling_auto as mapping + + result = build_model_class_mapping(mapping.AutoModelForCausalLM, "QEFFAutoModelForCausalLM") + assert "MistralConfig" in result + + def test_contains_qwen2_config(self): + import transformers.models.auto.modeling_auto as mapping + + result = build_model_class_mapping(mapping.AutoModelForCausalLM, "QEFFAutoModelForCausalLM") + assert "Qwen2Config" in result + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForCausalLM class structure +# --------------------------------------------------------------------------- + + +class TestQEFFAutoModelForCausalLMClassStructure: + """QEFFAutoModelForCausalLM must have correct class-level attributes.""" + + def test_has_pytorch_transforms_list(self): + assert hasattr(QEFFAutoModelForCausalLM, "_pytorch_transforms") + assert isinstance(QEFFAutoModelForCausalLM._pytorch_transforms, list) + assert len(QEFFAutoModelForCausalLM._pytorch_transforms) > 0 + + def test_has_onnx_transforms_list(self): + assert hasattr(QEFFAutoModelForCausalLM, "_onnx_transforms") + assert isinstance(QEFFAutoModelForCausalLM._onnx_transforms, list) + + def test_kv_cache_transform_in_pytorch_transforms(self): + transform_names = [ + t.__name__ if hasattr(t, "__name__") else str(t) for t in QEFFAutoModelForCausalLM._pytorch_transforms + ] + assert any("KVCache" in name for name in transform_names), ( + f"KVCacheTransform not found in _pytorch_transforms: {transform_names}" + ) + + def test_custom_ops_transform_in_pytorch_transforms(self): + transform_names = [ + t.__name__ if hasattr(t, "__name__") else str(t) for t in QEFFAutoModelForCausalLM._pytorch_transforms + ] + assert any("CustomOps" in name for name in transform_names), ( + f"CustomOpsTransform not found in _pytorch_transforms: {transform_names}" + ) + + def test_has_hf_auto_class(self): + assert hasattr(QEFFAutoModelForCausalLM, "_hf_auto_class") + + def test_has_from_pretrained_classmethod(self): + assert hasattr(QEFFAutoModelForCausalLM, "from_pretrained") + assert callable(QEFFAutoModelForCausalLM.from_pretrained) + + def test_importable_from_public_api(self): + import QEfficient + + assert hasattr(QEfficient, "QEFFAutoModelForCausalLM") + assert QEfficient.QEFFAutoModelForCausalLM is QEFFAutoModelForCausalLM + + def test_continuous_batching_flag_stored(self): + from transformers import GPT2Config, GPT2LMHeadModel + + cfg = GPT2Config(n_layer=1, n_head=2, n_embd=64, vocab_size=500, n_positions=32, n_ctx=32) + model = GPT2LMHeadModel(cfg) + qeff = QEFFAutoModelForCausalLM(model, continuous_batching=True) + assert qeff.continuous_batching is True + + def test_continuous_batching_defaults_to_false(self): + from transformers import GPT2Config, GPT2LMHeadModel + + cfg = GPT2Config(n_layer=1, n_head=2, n_embd=64, vocab_size=500, n_positions=32, n_ctx=32) + model = GPT2LMHeadModel(cfg) + qeff = QEFFAutoModelForCausalLM(model) + assert qeff.continuous_batching is False + + def test_model_name_property_returns_string(self): + from transformers import GPT2Config, GPT2LMHeadModel + + cfg = GPT2Config(n_layer=1, n_head=2, n_embd=64, vocab_size=500, n_positions=32, n_ctx=32) + model = GPT2LMHeadModel(cfg) + qeff = QEFFAutoModelForCausalLM(model) + assert hasattr(qeff, "model_name") + assert isinstance(qeff.model_name, str) + assert len(qeff.model_name) > 0 + + def test_model_attribute_is_transformed_model(self): + """After construction, qeff.model must be the KV-transformed model.""" + from transformers import GPT2Config, GPT2LMHeadModel + + from QEfficient.transformers.models.gpt2.modeling_gpt2 import QEffGPT2LMHeadModel + + cfg = GPT2Config(n_layer=1, n_head=2, n_embd=64, vocab_size=500, n_positions=32, n_ctx=32) + model = GPT2LMHeadModel(cfg) + qeff = QEFFAutoModelForCausalLM(model) + assert isinstance(qeff.model, QEffGPT2LMHeadModel), f"Expected QEffGPT2LMHeadModel, got {type(qeff.model)}" + + def test_onnx_transforms_contain_fp16_clip(self): + """ONNX transforms must include FP16ClipTransform.""" + transform_names = [ + t.__name__ if hasattr(t, "__name__") else str(t) for t in QEFFAutoModelForCausalLM._onnx_transforms + ] + assert any("FP16" in name or "Clip" in name for name in transform_names), ( + f"FP16ClipTransform not found in _onnx_transforms: {transform_names}" + ) + + +# --------------------------------------------------------------------------- +# Tests: Other QEff auto model class structures +# --------------------------------------------------------------------------- + + +class TestOtherQEffAutoModelClassStructures: + """Other QEff auto model classes must have correct class-level attributes.""" + + def test_qeff_auto_model_for_speech_seq2seq_has_from_pretrained(self): + assert hasattr(QEFFAutoModelForSpeechSeq2Seq, "from_pretrained") + assert callable(QEFFAutoModelForSpeechSeq2Seq.from_pretrained) + + def test_qeff_auto_model_for_speech_seq2seq_has_pytorch_transforms(self): + assert hasattr(QEFFAutoModelForSpeechSeq2Seq, "_pytorch_transforms") + assert isinstance(QEFFAutoModelForSpeechSeq2Seq._pytorch_transforms, list) + + def test_qeff_auto_model_for_speech_seq2seq_has_hf_auto_class(self): + assert hasattr(QEFFAutoModelForSpeechSeq2Seq, "_hf_auto_class") + + def test_qeff_auto_model_has_from_pretrained(self): + assert hasattr(QEFFAutoModel, "from_pretrained") + assert callable(QEFFAutoModel.from_pretrained) + + def test_qeff_auto_model_has_pytorch_transforms(self): + assert hasattr(QEFFAutoModel, "_pytorch_transforms") + + def test_qeff_auto_model_has_hf_auto_class(self): + assert hasattr(QEFFAutoModel, "_hf_auto_class") + + def test_qeff_auto_model_for_seq_classification_has_from_pretrained(self): + assert hasattr(QEFFAutoModelForSequenceClassification, "from_pretrained") + assert callable(QEFFAutoModelForSequenceClassification.from_pretrained) + + def test_qeff_auto_model_for_seq_classification_has_pytorch_transforms(self): + assert hasattr(QEFFAutoModelForSequenceClassification, "_pytorch_transforms") + + def test_qeff_auto_model_for_seq_classification_has_hf_auto_class(self): + assert hasattr(QEFFAutoModelForSequenceClassification, "_hf_auto_class") + + def test_misclassified_map_exists(self): + try: + from QEfficient.transformers.models.modeling_auto import ( + MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP, + ) + + assert isinstance(MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP, dict) + except ImportError: + pytest.skip("MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP not available") + + def test_qeff_auto_model_for_seq_classification_wraps_bert(self): + """QEFFAutoModelForSequenceClassification must wrap BERT without error.""" + from transformers import BertConfig, BertForSequenceClassification + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + num_labels=3, + ) + model = BertForSequenceClassification(cfg) + qeff = QEFFAutoModelForSequenceClassification(model) + assert qeff is not None + assert hasattr(qeff, "model") + + def test_qeff_auto_model_wraps_bert(self): + """QEFFAutoModel must wrap BERT without error.""" + from transformers import BertConfig, BertModel + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + model = BertModel(cfg) + qeff = QEFFAutoModel(model) + assert qeff is not None + assert hasattr(qeff, "model") + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForCausalLM error paths +# --------------------------------------------------------------------------- + + +class TestQEFFAutoModelForCausalLMErrorPaths: + """QEFFAutoModelForCausalLM must raise appropriate errors for invalid inputs.""" + + def test_non_causal_lm_model_raises_assertion_error(self): + """Passing a non-CausalLM model must raise AssertionError or TypeError.""" + from transformers import BertConfig, BertForSequenceClassification + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + num_labels=3, + ) + model = BertForSequenceClassification(cfg) + with pytest.raises((AssertionError, TypeError, ValueError)): + QEFFAutoModelForCausalLM(model) + + def test_bert_model_raises_error_when_passed_to_causal_lm(self): + """BertModel (not CausalLM) must raise an error.""" + from transformers import BertConfig, BertModel + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + model = BertModel(cfg) + with pytest.raises((AssertionError, TypeError, ValueError)): + QEFFAutoModelForCausalLM(model) + + def test_none_model_raises_error(self): + """Passing None must raise an error.""" + with pytest.raises((AssertionError, TypeError, AttributeError)): + QEFFAutoModelForCausalLM(None) + + +# --------------------------------------------------------------------------- +# Tests: QEFFAutoModelForSpeechSeq2Seq error paths +# --------------------------------------------------------------------------- + + +class TestQEFFAutoModelForSpeechSeq2SeqErrorPaths: + """QEFFAutoModelForSpeechSeq2Seq must raise appropriate errors for invalid inputs.""" + + def test_non_speech_model_raises_error(self): + """Passing a non-speech model must raise AssertionError or TypeError.""" + from transformers import GPT2Config, GPT2LMHeadModel + + cfg = GPT2Config(n_layer=1, n_head=2, n_embd=64, vocab_size=500, n_positions=32, n_ctx=32) + model = GPT2LMHeadModel(cfg) + with pytest.raises((AssertionError, TypeError, ValueError)): + QEFFAutoModelForSpeechSeq2Seq(model) + + def test_bert_model_raises_error_when_passed_to_speech_seq2seq(self): + """BertModel must raise an error when passed to QEFFAutoModelForSpeechSeq2Seq.""" + from transformers import BertConfig, BertModel + + cfg = BertConfig( + num_hidden_layers=1, + num_attention_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + model = BertModel(cfg) + with pytest.raises((AssertionError, TypeError, ValueError)): + QEFFAutoModelForSpeechSeq2Seq(model) + + +# --------------------------------------------------------------------------- +# Tests: MODEL_CLASS_MAPPING completeness +# --------------------------------------------------------------------------- + + +class TestModelClassMappingCompleteness: + """MODEL_CLASS_MAPPING must include VLM config classes.""" + + def test_contains_llava_config(self): + from QEfficient.transformers.modeling_utils import MODEL_CLASS_MAPPING + + # LlavaConfig should map to QEFFAutoModelForImageTextToText + assert "LlavaConfig" in MODEL_CLASS_MAPPING, "LlavaConfig missing from MODEL_CLASS_MAPPING" + + def test_llava_config_maps_to_vlm_class(self): + from QEfficient.transformers.modeling_utils import MODEL_CLASS_MAPPING + + if "LlavaConfig" in MODEL_CLASS_MAPPING: + assert ( + "ImageTextToText" in MODEL_CLASS_MAPPING["LlavaConfig"] + or "CausalLM" in MODEL_CLASS_MAPPING["LlavaConfig"] + ), f"LlavaConfig maps to unexpected class: {MODEL_CLASS_MAPPING['LlavaConfig']}" + + def test_all_values_are_qeff_class_names(self): + from QEfficient.transformers.modeling_utils import MODEL_CLASS_MAPPING + + for key, value in MODEL_CLASS_MAPPING.items(): + assert isinstance(value, str), f"Expected string value for key '{key}', got {type(value)}" + assert "QEFF" in value or "QEff" in value, f"Expected QEff class name for key '{key}', got: {value}" + + +# --------------------------------------------------------------------------- +# Tests: SPECIALIZED_DISAGG_SERVING_MODEL_ARCH +# --------------------------------------------------------------------------- + + +class TestSpecializedDisaggServingModelArch: + """SPECIALIZED_DISAGG_SERVING_MODEL_ARCH must contain expected model types.""" + + def test_exists_and_is_set_or_collection(self): + from QEfficient.transformers.modeling_utils import SPECIALIZED_DISAGG_SERVING_MODEL_ARCH + + assert hasattr(SPECIALIZED_DISAGG_SERVING_MODEL_ARCH, "__contains__") + + def test_contains_gpt_oss(self): + from QEfficient.transformers.modeling_utils import SPECIALIZED_DISAGG_SERVING_MODEL_ARCH + + assert "gpt_oss" in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH + + def test_all_entries_are_strings(self): + from QEfficient.transformers.modeling_utils import SPECIALIZED_DISAGG_SERVING_MODEL_ARCH + + for arch in SPECIALIZED_DISAGG_SERVING_MODEL_ARCH: + assert isinstance(arch, str), f"Expected string, got {type(arch)}: {arch}" diff --git a/tests/unit_test/utils/test_padding_and_shapes.py b/tests/unit_test/utils/test_padding_and_shapes.py new file mode 100644 index 000000000..266d0f6fe --- /dev/null +++ b/tests/unit_test/utils/test_padding_and_shapes.py @@ -0,0 +1,615 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +""" +Tests for utility functions: get_padding_shape_from_config, sampler_utils, hash_utils. + +Tests verify: + - get_padding_shape_from_config: correct KV cache shapes for various model configs + - get_sampling_inputs_and_outputs: correct input/output names for sampler + - hash_dict_params: deterministic, correct length, different configs → different hashes + +All tests run on CPU only. +""" + +import pytest +import torch +from transformers import ( + GPT2Config, + LlamaConfig, + MistralConfig, +) + +from QEfficient.utils.constants import HASH_HEXDIGEST_STR_LEN +from QEfficient.utils.hash_utils import hash_dict_params +from QEfficient.utils.sampler_utils import get_sampling_inputs_and_outputs + +# --------------------------------------------------------------------------- +# Helpers: get_padding_shape_from_config +# --------------------------------------------------------------------------- + + +def _get_padding_shape(config, batch_size=1, seq_len=32): + """Import and call get_padding_shape_from_config.""" + from QEfficient.utils import get_padding_shape_from_config + + return get_padding_shape_from_config(config, batch_size, seq_len) + + +# --------------------------------------------------------------------------- +# Tests: get_padding_shape_from_config +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestGetPaddingShapeFromConfig: + """get_padding_shape_from_config must return correct KV cache shapes.""" + + def test_llama_returns_correct_shape(self): + """Llama: shape must be [batch, n_kv_heads, seq_len, head_dim].""" + cfg = LlamaConfig( + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + shape = _get_padding_shape(cfg, batch_size=1, seq_len=32) + assert len(shape) == 4, f"Expected 4D shape, got {len(shape)}D: {shape}" + assert shape[0] == 1 # batch_size + assert shape[1] == 4 # n_kv_heads + assert shape[2] == 32 # seq_len + assert shape[3] == 16 # head_dim = hidden_size / num_attention_heads = 64/4 + + def test_gpt2_returns_correct_shape(self): + """GPT2: shape must be [batch, n_heads, seq_len, head_dim].""" + cfg = GPT2Config( + n_layer=2, + n_head=4, + n_embd=64, + vocab_size=500, + n_positions=64, + n_ctx=64, + ) + shape = _get_padding_shape(cfg, batch_size=1, seq_len=32) + assert len(shape) == 4 + assert shape[0] == 1 + assert shape[2] == 32 + + def test_mistral_gqa_returns_correct_kv_heads(self): + """Mistral with GQA: n_kv_heads must be less than n_heads.""" + cfg = MistralConfig( + num_hidden_layers=2, + num_attention_heads=8, + num_key_value_heads=2, # GQA: 2 KV heads for 8 query heads + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + shape = _get_padding_shape(cfg, batch_size=1, seq_len=32) + assert len(shape) == 4 + assert shape[1] == 2, f"Expected 2 KV heads for GQA, got {shape[1]}" + + def test_shape_has_4_dimensions(self): + """Shape must always have exactly 4 dimensions for standard models.""" + cfg = LlamaConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + shape = _get_padding_shape(cfg, batch_size=2, seq_len=16) + assert len(shape) == 4 + + def test_batch_size_reflected_in_shape(self): + """Batch size must be reflected in the first dimension of the shape.""" + cfg = LlamaConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + shape = _get_padding_shape(cfg, batch_size=4, seq_len=32) + assert shape[0] == 4 + + def test_seq_len_reflected_in_shape(self): + """Sequence length must be reflected in the third dimension of the shape.""" + cfg = LlamaConfig( + num_hidden_layers=2, + num_attention_heads=2, + num_key_value_heads=2, + hidden_size=64, + intermediate_size=128, + vocab_size=500, + max_position_embeddings=64, + ) + shape = _get_padding_shape(cfg, batch_size=1, seq_len=64) + assert shape[2] == 64 + + def test_head_dim_is_hidden_size_divided_by_num_heads(self): + """head_dim must equal hidden_size / num_attention_heads.""" + hidden_size = 128 + num_heads = 8 + cfg = LlamaConfig( + num_hidden_layers=2, + num_attention_heads=num_heads, + num_key_value_heads=num_heads, + hidden_size=hidden_size, + intermediate_size=256, + vocab_size=500, + max_position_embeddings=64, + ) + shape = _get_padding_shape(cfg, batch_size=1, seq_len=32) + expected_head_dim = hidden_size // num_heads + assert shape[3] == expected_head_dim, f"Expected head_dim={expected_head_dim}, got {shape[3]}" + + +# --------------------------------------------------------------------------- +# Tests: get_sampling_inputs_and_outputs +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestSamplerUtils: + """get_sampling_inputs_and_outputs must return correct input/output names.""" + + def _make_base_inputs(self, batch=1, seq_len=8): + """Create minimal example inputs for sampler utils.""" + return { + "input_ids": torch.zeros((batch, seq_len), dtype=torch.int64), + "position_ids": torch.arange(seq_len).unsqueeze(0).expand(batch, -1), + } + + def _make_base_dynamic_axes(self): + return { + "input_ids": {0: "batch_size", 1: "seq_len"}, + "position_ids": {0: "batch_size", 1: "seq_len"}, + } + + def test_get_sampling_inputs_returns_temperatures(self): + """Sampler inputs must include 'temperatures'.""" + inputs = self._make_base_inputs() + output_names = ["logits"] + dynamic_axes = self._make_base_dynamic_axes() + qaic_config = {"max_top_k_ids": 512} + + updated_inputs, _, _ = get_sampling_inputs_and_outputs( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + continuous_batching=False, + vocab_size=500, + qaic_config=qaic_config, + ) + assert "temperatures" in updated_inputs + + def test_get_sampling_inputs_returns_top_ks(self): + """Sampler inputs must include 'top_ks'.""" + inputs = self._make_base_inputs() + output_names = ["logits"] + dynamic_axes = self._make_base_dynamic_axes() + qaic_config = {"max_top_k_ids": 512} + + updated_inputs, _, _ = get_sampling_inputs_and_outputs( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + continuous_batching=False, + vocab_size=500, + qaic_config=qaic_config, + ) + assert "top_ks" in updated_inputs + + def test_get_sampling_inputs_returns_top_ps(self): + """Sampler inputs must include 'top_ps'.""" + inputs = self._make_base_inputs() + output_names = ["logits"] + dynamic_axes = self._make_base_dynamic_axes() + qaic_config = {"max_top_k_ids": 512} + + updated_inputs, _, _ = get_sampling_inputs_and_outputs( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + continuous_batching=False, + vocab_size=500, + qaic_config=qaic_config, + ) + assert "top_ps" in updated_inputs + + def test_get_sampling_inputs_returns_repetition_penalties(self): + """Sampler inputs must include 'repetition_penalties'.""" + inputs = self._make_base_inputs() + output_names = ["logits"] + dynamic_axes = self._make_base_dynamic_axes() + qaic_config = {"max_top_k_ids": 512} + + updated_inputs, _, _ = get_sampling_inputs_and_outputs( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + continuous_batching=False, + vocab_size=500, + qaic_config=qaic_config, + ) + assert "repetition_penalties" in updated_inputs + + def test_get_sampling_inputs_returns_random_numbers(self): + """Sampler inputs must include 'random_numbers'.""" + inputs = self._make_base_inputs() + output_names = ["logits"] + dynamic_axes = self._make_base_dynamic_axes() + qaic_config = {"max_top_k_ids": 512} + + updated_inputs, _, _ = get_sampling_inputs_and_outputs( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + continuous_batching=False, + vocab_size=500, + qaic_config=qaic_config, + ) + assert "random_numbers" in updated_inputs + + def test_get_sampling_outputs_includes_retained_state(self): + """Sampler outputs must include retained state buffers.""" + inputs = self._make_base_inputs() + output_names = ["logits"] + dynamic_axes = self._make_base_dynamic_axes() + qaic_config = {"max_top_k_ids": 512} + + _, updated_output_names, _ = get_sampling_inputs_and_outputs( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + continuous_batching=False, + vocab_size=500, + qaic_config=qaic_config, + ) + # Must include retained state outputs + retained_state_outputs = [n for n in updated_output_names if "_RetainedState" in n] + assert len(retained_state_outputs) > 0, "Sampler must add RetainedState outputs" + + def test_get_sampling_inputs_includes_last_accepted_output_tokens(self): + """Sampler inputs must include 'last_accepted_output_tokens'.""" + inputs = self._make_base_inputs() + output_names = ["logits"] + dynamic_axes = self._make_base_dynamic_axes() + qaic_config = {"max_top_k_ids": 512} + + updated_inputs, _, _ = get_sampling_inputs_and_outputs( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + continuous_batching=False, + vocab_size=500, + qaic_config=qaic_config, + ) + assert "last_accepted_output_tokens" in updated_inputs + + def test_get_sampling_dynamic_axes_updated(self): + """Dynamic axes must be updated for all new sampler inputs.""" + inputs = self._make_base_inputs() + output_names = ["logits"] + dynamic_axes = self._make_base_dynamic_axes() + qaic_config = {"max_top_k_ids": 512} + + _, _, updated_axes = get_sampling_inputs_and_outputs( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + continuous_batching=False, + vocab_size=500, + qaic_config=qaic_config, + ) + assert "temperatures" in updated_axes + assert "top_ks" in updated_axes + assert "top_ps" in updated_axes + + def test_get_sampling_inputs_tensor_shapes_are_correct(self): + """Sampler input tensors must have correct shapes (batch dim >= 1).""" + batch = 1 + inputs = self._make_base_inputs(batch=batch) + output_names = ["logits"] + dynamic_axes = self._make_base_dynamic_axes() + qaic_config = {"max_top_k_ids": 512} + + updated_inputs, _, _ = get_sampling_inputs_and_outputs( + example_inputs=inputs, + output_names=output_names, + dynamic_axes=dynamic_axes, + continuous_batching=False, + vocab_size=500, + qaic_config=qaic_config, + ) + # temperatures must be a tensor with at least 1 element + assert updated_inputs["temperatures"].numel() >= 1 + # top_ks must be a tensor with at least 1 element + assert updated_inputs["top_ks"].numel() >= 1 + # top_ps must be a tensor with at least 1 element + assert updated_inputs["top_ps"].numel() >= 1 + + +# --------------------------------------------------------------------------- +# Tests: hash_utils +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestHashUtils: + """hash_dict_params must be deterministic, correct length, and collision-resistant.""" + + def test_compute_hash_returns_string(self): + """hash_dict_params must return a string.""" + result = hash_dict_params({"key": "value"}) + assert isinstance(result, str) + + def test_compute_hash_is_deterministic(self): + """Same input must always produce the same hash.""" + params = {"model": "llama", "layers": 2, "heads": 4} + hash1 = hash_dict_params(params) + hash2 = hash_dict_params(params) + assert hash1 == hash2, "hash_dict_params must be deterministic" + + def test_different_configs_produce_different_hashes(self): + """Different configs must produce different hashes.""" + params1 = {"model": "llama", "layers": 2} + params2 = {"model": "llama", "layers": 4} + hash1 = hash_dict_params(params1) + hash2 = hash_dict_params(params2) + assert hash1 != hash2, "Different configs must produce different hashes" + + def test_hash_length_is_correct(self): + """Hash must have length HASH_HEXDIGEST_STR_LEN (16).""" + result = hash_dict_params({"key": "value"}) + assert len(result) == HASH_HEXDIGEST_STR_LEN, ( + f"Expected hash length {HASH_HEXDIGEST_STR_LEN}, got {len(result)}" + ) + + def test_hash_is_hexadecimal(self): + """Hash must consist of hexadecimal characters only.""" + result = hash_dict_params({"key": "value", "num": 42}) + assert all(c in "0123456789abcdef" for c in result), f"Hash must be hexadecimal, got: {result}" + + def test_empty_dict_produces_valid_hash(self): + """Empty dict must produce a valid hash.""" + result = hash_dict_params({}) + assert isinstance(result, str) + assert len(result) == HASH_HEXDIGEST_STR_LEN + + def test_nested_dict_produces_valid_hash(self): + """Nested dict must produce a valid hash.""" + params = {"outer": {"inner": "value"}, "num": 42} + result = hash_dict_params(params) + assert isinstance(result, str) + assert len(result) == HASH_HEXDIGEST_STR_LEN + + def test_order_independent_hashing(self): + """Dict with same keys in different order must produce the same hash (sort_keys=True).""" + params1 = {"b": 2, "a": 1} + params2 = {"a": 1, "b": 2} + hash1 = hash_dict_params(params1) + hash2 = hash_dict_params(params2) + assert hash1 == hash2, "Hash must be order-independent (sort_keys=True)" + + def test_custom_hash_size(self): + """Custom hash_string_size must be respected.""" + result = hash_dict_params({"key": "value"}, hash_string_size=8) + assert len(result) == 8 + + +# --------------------------------------------------------------------------- +# Tests: process_ccl_specializations (GAP H) +# --------------------------------------------------------------------------- + + +@pytest.mark.cpu_only +class TestCheckCCLSpecializations: + """Tests for process_ccl_specializations and related CCL utility functions.""" + + def test_process_ccl_specializations_returns_three_values(self): + """process_ccl_specializations must return (ccl_prefill, ccl_decode, ctx_len).""" + from QEfficient.utils.check_ccl_specializations import process_ccl_specializations + + result = process_ccl_specializations(None, None, ctx_len=4096, prefill_seq_len=128) + assert len(result) == 3 + + def test_process_ccl_specializations_returns_lists(self): + """process_ccl_specializations must return lists for prefill and decode.""" + from QEfficient.utils.check_ccl_specializations import process_ccl_specializations + + ccl_prefill, ccl_decode, ctx_len = process_ccl_specializations(None, None, ctx_len=4096, prefill_seq_len=128) + assert isinstance(ccl_prefill, list) + assert isinstance(ccl_decode, list) + + def test_process_ccl_specializations_lists_not_empty(self): + """process_ccl_specializations must return non-empty lists.""" + from QEfficient.utils.check_ccl_specializations import process_ccl_specializations + + ccl_prefill, ccl_decode, ctx_len = process_ccl_specializations(None, None, ctx_len=4096, prefill_seq_len=128) + assert len(ccl_prefill) > 0 + assert len(ccl_decode) > 0 + + def test_process_ccl_specializations_last_element_leq_ctx_len(self): + """Last element of CCL lists must be <= ctx_len.""" + from QEfficient.utils.check_ccl_specializations import process_ccl_specializations + + ctx_len = 4096 + ccl_prefill, ccl_decode, returned_ctx_len = process_ccl_specializations( + None, None, ctx_len=ctx_len, prefill_seq_len=128 + ) + assert ccl_prefill[-1] <= ctx_len + assert ccl_decode[-1] <= ctx_len + + def test_process_ccl_specializations_with_explicit_lists(self): + """process_ccl_specializations with explicit lists must validate and return them.""" + from QEfficient.utils.check_ccl_specializations import process_ccl_specializations + + ccl_prefill, ccl_decode, ctx_len = process_ccl_specializations( + [512, 1024], [1024, 2048], ctx_len=4096, prefill_seq_len=128 + ) + assert isinstance(ccl_prefill, list) + assert isinstance(ccl_decode, list) + + def test_process_ccl_specializations_with_only_prefill(self): + """process_ccl_specializations with only prefill list must fill decode with ctx_len.""" + from QEfficient.utils.check_ccl_specializations import process_ccl_specializations + + ccl_prefill, ccl_decode, ctx_len = process_ccl_specializations( + [512, 1024], None, ctx_len=4096, prefill_seq_len=128 + ) + assert isinstance(ccl_prefill, list) + assert isinstance(ccl_decode, list) + assert len(ccl_decode) > 0 + + def test_process_ccl_specializations_with_only_decode(self): + """process_ccl_specializations with only decode list must fill prefill with ctx_len.""" + from QEfficient.utils.check_ccl_specializations import process_ccl_specializations + + ccl_prefill, ccl_decode, ctx_len = process_ccl_specializations( + None, [1024, 2048], ctx_len=4096, prefill_seq_len=128 + ) + assert isinstance(ccl_prefill, list) + assert isinstance(ccl_decode, list) + assert len(ccl_prefill) > 0 + + def test_process_ccl_specializations_prefill_seq_len_1(self): + """With prefill_seq_len=1, prefill and decode lists must be identical.""" + from QEfficient.utils.check_ccl_specializations import process_ccl_specializations + + ccl_prefill, ccl_decode, ctx_len = process_ccl_specializations(None, None, ctx_len=4096, prefill_seq_len=1) + assert ccl_prefill == ccl_decode, "With prefill_seq_len=1, prefill and decode CCL lists must be identical" + + +@pytest.mark.cpu_only +class TestAutomaticCCLGeneration: + """Tests for automatic_ccl_generation utility function.""" + + def test_automatic_ccl_generation_returns_three_values(self): + """automatic_ccl_generation must return (prefill_list, decode_list, mapped_cl).""" + from QEfficient.utils.check_ccl_specializations import automatic_ccl_generation + + result = automatic_ccl_generation(ctx_len=4096, prefill_seq_len=128) + assert len(result) == 3 + + def test_automatic_ccl_generation_returns_lists(self): + """automatic_ccl_generation must return lists.""" + from QEfficient.utils.check_ccl_specializations import automatic_ccl_generation + + prefill_list, decode_list, mapped_cl = automatic_ccl_generation(ctx_len=4096, prefill_seq_len=128) + assert isinstance(prefill_list, list) + assert isinstance(decode_list, list) + + def test_automatic_ccl_generation_mapped_cl_is_multiple_of_1024(self): + """mapped_cl must be a multiple of 1024.""" + from QEfficient.utils.check_ccl_specializations import automatic_ccl_generation + + _, _, mapped_cl = automatic_ccl_generation(ctx_len=3000, prefill_seq_len=128) + assert mapped_cl % 1024 == 0, f"mapped_cl={mapped_cl} must be a multiple of 1024" + + def test_automatic_ccl_generation_small_ctx_len(self): + """automatic_ccl_generation with small ctx_len must return valid lists.""" + from QEfficient.utils.check_ccl_specializations import automatic_ccl_generation + + prefill_list, decode_list, mapped_cl = automatic_ccl_generation(ctx_len=512, prefill_seq_len=128) + assert len(prefill_list) > 0 + assert len(decode_list) > 0 + + def test_automatic_ccl_generation_zero_ctx_len(self): + """automatic_ccl_generation with ctx_len=0 must return valid lists.""" + from QEfficient.utils.check_ccl_specializations import automatic_ccl_generation + + prefill_list, decode_list, mapped_cl = automatic_ccl_generation(ctx_len=0, prefill_seq_len=128) + assert len(prefill_list) > 0 + assert len(decode_list) > 0 + + +@pytest.mark.cpu_only +class TestCCLHelperFunctions: + """Tests for CCL helper functions: next_multiple_of_1024, build_doubling_list, etc.""" + + def test_next_multiple_of_1024_rounds_up(self): + """next_multiple_of_1024 must round up to the next multiple of 1024.""" + from QEfficient.utils.check_ccl_specializations import next_multiple_of_1024 + + assert next_multiple_of_1024(1) == 1024 + assert next_multiple_of_1024(1024) == 1024 + assert next_multiple_of_1024(1025) == 2048 + assert next_multiple_of_1024(2048) == 2048 + assert next_multiple_of_1024(2049) == 3072 + + def test_next_multiple_of_1024_zero_or_negative(self): + """next_multiple_of_1024 with n<=0 must return 0.""" + from QEfficient.utils.check_ccl_specializations import next_multiple_of_1024 + + assert next_multiple_of_1024(0) == 0 + assert next_multiple_of_1024(-1) == 0 + + def test_build_doubling_list_basic(self): + """build_doubling_list must return a doubling sequence.""" + from QEfficient.utils.check_ccl_specializations import build_doubling_list + + result = build_doubling_list(start=1024, limit=8192, max_elements=5) + assert result[0] == 1024 + # Each element must be double the previous + for i in range(1, len(result)): + assert result[i] == result[i - 1] * 2 or result[i] <= 8192 + + def test_build_doubling_list_respects_max_elements(self): + """build_doubling_list must not exceed max_elements.""" + from QEfficient.utils.check_ccl_specializations import build_doubling_list + + result = build_doubling_list(start=1024, limit=1024 * 1024, max_elements=4) + assert len(result) <= 4 + + def test_build_doubling_list_respects_limit(self): + """build_doubling_list must not exceed limit.""" + from QEfficient.utils.check_ccl_specializations import build_doubling_list + + limit = 4096 + result = build_doubling_list(start=1024, limit=limit, max_elements=10) + for val in result: + assert val <= limit, f"Value {val} exceeds limit {limit}" + + def test_build_doubling_list_with_last_value(self): + """build_doubling_list with last_value must end with that value.""" + from QEfficient.utils.check_ccl_specializations import build_doubling_list + + result = build_doubling_list(start=1024, limit=8192, max_elements=5, last_value=8192) + assert result[-1] == 8192 + + def test_is_power_of_two(self): + """is_power_of_two must correctly identify powers of two.""" + from QEfficient.utils.check_ccl_specializations import is_power_of_two + + assert is_power_of_two(1) + assert is_power_of_two(2) + assert is_power_of_two(4) + assert is_power_of_two(1024) + assert is_power_of_two(4096) + assert not is_power_of_two(3) + assert not is_power_of_two(5) + assert not is_power_of_two(0) + assert not is_power_of_two(-1) + + def test_floor_to_1000(self): + """floor_to_1000 must floor to the nearest lower multiple of 1000.""" + from QEfficient.utils.check_ccl_specializations import floor_to_1000 + + assert floor_to_1000(1500) == 1000 + assert floor_to_1000(2000) == 2000 + assert floor_to_1000(999) == 0 + assert floor_to_1000(0) == 0 + assert floor_to_1000(-1) == 0