From 58b8e6e79ffe19ad18663151d39539a519810e51 Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Wed, 6 Aug 2025 19:00:00 +0000 Subject: [PATCH 01/37] [QEff]: Add gpt_oss Signed-off-by: vbaddi Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/modeling_auto.py | 1 + 1 file changed, 1 insertion(+) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 8edc1f3f0..9543346b5 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2354,6 +2354,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): CustomOpsTransform, KVCacheTransform, SplitGateUpWeightsTransform, + SplitGateUpWeightsTransformGPTOSS, KVCacheExternalModuleMapperTransform, ] From 582fc177131dd41d91ff447b564cc435bf033fea Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Thu, 7 Aug 2025 14:34:08 +0000 Subject: [PATCH 02/37] nit: update modeling and make transform uniform Signed-off-by: vbaddi Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/modeling_auto.py | 1 - pyproject.toml | 4 ++-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 9543346b5..8edc1f3f0 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2354,7 +2354,6 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): CustomOpsTransform, KVCacheTransform, SplitGateUpWeightsTransform, - SplitGateUpWeightsTransformGPTOSS, KVCacheExternalModuleMapperTransform, ] diff --git a/pyproject.toml b/pyproject.toml index 8e179ab4a..cc38f4bf8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,8 @@ dependencies = [ "transformers==4.55.0", "huggingface-hub==0.34.0", "hf_transfer==0.1.9", - "peft==0.13.2", - "datasets==2.20.0", + "peft", + "datasets", "fsspec==2023.6.0", "multidict==6.0.4", "urllib3<2", From 6352ac2a88aa6a027c095d1993a7452a42df56f0 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 7 Aug 2025 15:23:21 +0530 Subject: [PATCH 03/37] apirunner change Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/modeling_auto.py | 5 +++++ QEfficient/utils/generate_inputs.py | 1 + 2 files changed, 6 insertions(+) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 8edc1f3f0..f4a59a2c5 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3082,6 +3082,11 @@ def compile( for kv in ["key", "value"]: custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype + # HACK for now + if self.model.config.model_type == "gpt_oss": + for spec in specializations: + spec.update({"sliding_window": 128}) + qpc_path = self._compile( onnx_path=onnx_path, compile_dir=compile_dir, diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index 95474acfd..5cacd6f84 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -92,6 +92,7 @@ def prepare_pytorch_inputs(self): inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1) past_key_values = [] + sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + self.padding_shape[-1] for i in range(self.n_layer): if ( all(hasattr(self.config, attr) for attr in ["sliding_window", "layer_types"]) From 296dc9a845d2df2ad8f44755763a0b36c0767f31 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: 
Thu, 7 Aug 2025 19:24:24 +0530 Subject: [PATCH 04/37] added test along with simplified Hybridcache Signed-off-by: Onkar Chougule --- QEfficient/utils/generate_inputs.py | 2 +- tests/test_gpt.py | 61 +++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 1 deletion(-) create mode 100644 tests/test_gpt.py diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index 5cacd6f84..cb2a68fa9 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -92,7 +92,7 @@ def prepare_pytorch_inputs(self): inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1) past_key_values = [] - sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + self.padding_shape[-1] + sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]] for i in range(self.n_layer): if ( all(hasattr(self.config, attr) for attr in ["sliding_window", "layer_types"]) diff --git a/tests/test_gpt.py b/tests/test_gpt.py new file mode 100644 index 000000000..27b423b63 --- /dev/null +++ b/tests/test_gpt.py @@ -0,0 +1,61 @@ +import torch +from transformers import AutoConfig, AutoModelForCausalLM, GptOssForCausalLM, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.utils._utils import load_hf_tokenizer +from QEfficient.utils.constants import Constants +from QEfficient.utils.run_utils import ApiRunner + +Constants.INPUT_STR=["Make sure tokens don't repeat\n\nTo make a simple cup of coffee, start by boiling water. Add one to two teaspoons of instant coffee powder to a mug. Pour the hot water over the coffee and stir well. Add sugar and milk to taste, if desired. For brewed coffee, use a French press or drip filter. Add coarsely ground coffee to the device, pour hot water over it, and let it steep for four minutes. 
Press or filter the coffee, then serve"] + +torch.manual_seed(42) +model_id = "openai/gpt-oss-20b" +config = AutoConfig.from_pretrained(model_id) +config.num_hidden_layers=2 + +# Remove the quantization_config attribute if it exists, to avoid MXFP4 Issues +if hasattr(config, "quantization_config"): + delattr(config, "quantization_config") + +model = GptOssForCausalLM.from_pretrained( + "/home/vbaddi/transformers/src/transformers/models/gpt_oss/new_weights", torch_dtype=torch.float32, attn_implementation="eager", config=config +) +model.eval() +model.generation_config.sample=False +tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_id) +config = model.config +batch_size = len(Constants.INPUT_STR) + +api_runner = ApiRunner(batch_size, tokenizer, config, Constants.INPUT_STR, 97, 256) +pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model) + + +qeff_model = QEFFAutoModelForCausalLM(model, continuous_batching=False) +# pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) + +onnx_model_path = qeff_model.export() + + +qpc_path = qeff_model.compile( + prefill_seq_len=128, + ctx_len=256, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, +) +print(f"qpc path is {qpc_path}") +streamer = TextStreamer(tokenizer) +exec_info = qeff_model.generate( + tokenizer, + streamer=streamer, + prompts=Constants.INPUT_STR[0], + device_ids=[0], +) + +import ipdb; ipdb.set_trace() +print(pytorch_hf_tokens) +print(exec_info) From 6c9e79cedef98269969f925f58592d71c96ba64e Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 7 Aug 2025 19:26:57 +0530 Subject: [PATCH 05/37] added test assert Signed-off-by: Onkar Chougule --- tests/test_gpt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 27b423b63..92c17c353 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -56,6 +56,6 @@ device_ids=[0], ) -import ipdb; ipdb.set_trace() print(pytorch_hf_tokens) print(exec_info) +assert (exec_info.generated_ids[0][0,:159] == pytorch_hf_tokens).all() From df5dd62c5fb75dd43432586d1369744ec21822ee Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Fri, 8 Aug 2025 02:44:05 +0000 Subject: [PATCH 06/37] nit: update test gpt file Signed-off-by: vbaddi Signed-off-by: Onkar Chougule --- tests/test_gpt.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/tests/test_gpt.py b/tests/test_gpt.py index 92c17c353..8e44f2f82 100644 --- a/tests/test_gpt.py +++ b/tests/test_gpt.py @@ -1,27 +1,39 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + import torch -from transformers import AutoConfig, AutoModelForCausalLM, GptOssForCausalLM, TextStreamer +from transformers import AutoConfig, GptOssForCausalLM, TextStreamer from QEfficient import QEFFAutoModelForCausalLM from QEfficient.utils._utils import load_hf_tokenizer from QEfficient.utils.constants import Constants from QEfficient.utils.run_utils import ApiRunner -Constants.INPUT_STR=["Make sure tokens don't repeat\n\nTo make a simple cup of coffee, start by boiling water. Add one to two teaspoons of instant coffee powder to a mug. Pour the hot water over the coffee and stir well. Add sugar and milk to taste, if desired. 
For brewed coffee, use a French press or drip filter. Add coarsely ground coffee to the device, pour hot water over it, and let it steep for four minutes. Press or filter the coffee, then serve"] +Constants.INPUT_STR = [ + "Make sure tokens don't repeat\n\nTo make a simple cup of coffee, start by boiling water. Add one to two teaspoons of instant coffee powder to a mug. Pour the hot water over the coffee and stir well. Add sugar and milk to taste, if desired. For brewed coffee, use a French press or drip filter. Add coarsely ground coffee to the device, pour hot water over it, and let it steep for four minutes. Press or filter the coffee, then serve" +] torch.manual_seed(42) model_id = "openai/gpt-oss-20b" config = AutoConfig.from_pretrained(model_id) -config.num_hidden_layers=2 +config.num_hidden_layers = 2 # Remove the quantization_config attribute if it exists, to avoid MXFP4 Issues if hasattr(config, "quantization_config"): delattr(config, "quantization_config") model = GptOssForCausalLM.from_pretrained( - "/home/vbaddi/transformers/src/transformers/models/gpt_oss/new_weights", torch_dtype=torch.float32, attn_implementation="eager", config=config + "/home/vbaddi/transformers/src/transformers/models/gpt_oss/new_weights", + torch_dtype=torch.float32, + attn_implementation="eager", + config=config, ) model.eval() -model.generation_config.sample=False +model.generation_config.sample = False tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_id) config = model.config batch_size = len(Constants.INPUT_STR) @@ -58,4 +70,4 @@ print(pytorch_hf_tokens) print(exec_info) -assert (exec_info.generated_ids[0][0,:159] == pytorch_hf_tokens).all() +assert (exec_info.generated_ids[0][0, :159] == pytorch_hf_tokens).all() From f806ac5c2625f1d8bfadedd83f941fe875a26bcc Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Mon, 11 Aug 2025 07:00:22 +0000 Subject: [PATCH 07/37] nit: update modeling with new decode moe forward Signed-off-by: vbaddi Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/modeling_auto.py | 2 +- QEfficient/transformers/models/pytorch_transforms.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index f4a59a2c5..e669e7b9a 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2353,7 +2353,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): Mxfp4GptOssExpertDequantizeTransform, CustomOpsTransform, KVCacheTransform, - SplitGateUpWeightsTransform, + # SplitGateUpWeightsTransform, KVCacheExternalModuleMapperTransform, ] diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 21a867eb5..3ecc29d76 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -55,7 +55,6 @@ from transformers.models.gpt_oss.modeling_gpt_oss import ( GptOssAttention, GptOssDecoderLayer, - GptOssExperts, GptOssForCausalLM, GptOssMLP, GptOssModel, @@ -257,7 +256,6 @@ from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import ( QEffGptOssAttention, QEffGptOssDecoderLayer, - QEffGptOssExperts, QEffGptOssForCausalLM, QEffGptOssMLP, QEffGptOssModel, @@ -529,7 +527,7 @@ class KVCacheTransform(ModuleMappingTransform): GptOssModel: QEffGptOssModel, GptOssForCausalLM: QEffGptOssForCausalLM, GptOssMLP: QEffGptOssMLP, - GptOssExperts: QEffGptOssExperts, + # GptOssExperts: 
QEffGptOssExperts, # Granite GraniteModel: QEffGraniteModel, GraniteForCausalLM: QEffGraniteForCausalLM, From 0a6aa9cf4d635b41155982a6a48d1686e81c2c27 Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Wed, 20 Aug 2025 08:50:05 +0000 Subject: [PATCH 08/37] nit: seperate gate, up projections for MoE Signed-off-by: vbaddi Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/modeling_auto.py | 2 +- QEfficient/transformers/models/pytorch_transforms.py | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index e669e7b9a..f4a59a2c5 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2353,7 +2353,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): Mxfp4GptOssExpertDequantizeTransform, CustomOpsTransform, KVCacheTransform, - # SplitGateUpWeightsTransform, + SplitGateUpWeightsTransform, KVCacheExternalModuleMapperTransform, ] diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 3ecc29d76..21a867eb5 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -55,6 +55,7 @@ from transformers.models.gpt_oss.modeling_gpt_oss import ( GptOssAttention, GptOssDecoderLayer, + GptOssExperts, GptOssForCausalLM, GptOssMLP, GptOssModel, @@ -256,6 +257,7 @@ from QEfficient.transformers.models.gpt_oss.modeling_gpt_oss import ( QEffGptOssAttention, QEffGptOssDecoderLayer, + QEffGptOssExperts, QEffGptOssForCausalLM, QEffGptOssMLP, QEffGptOssModel, @@ -527,7 +529,7 @@ class KVCacheTransform(ModuleMappingTransform): GptOssModel: QEffGptOssModel, GptOssForCausalLM: QEffGptOssForCausalLM, GptOssMLP: QEffGptOssMLP, - # GptOssExperts: QEffGptOssExperts, + GptOssExperts: QEffGptOssExperts, # Granite GraniteModel: QEffGraniteModel, GraniteForCausalLM: QEffGraniteForCausalLM, From 7731691e0fcd0638e316f7a3911c3bfbdfa79179 Mon Sep 17 00:00:00 2001 From: Vinayak Baddi Date: Wed, 15 Oct 2025 08:57:42 +0000 Subject: [PATCH 09/37] nit: remove test file and add sample test in config Signed-off-by: vbaddi Signed-off-by: Onkar Chougule --- tests/test_gpt.py | 73 ----------------------------------------------- 1 file changed, 73 deletions(-) delete mode 100644 tests/test_gpt.py diff --git a/tests/test_gpt.py b/tests/test_gpt.py deleted file mode 100644 index 8e44f2f82..000000000 --- a/tests/test_gpt.py +++ /dev/null @@ -1,73 +0,0 @@ -# ----------------------------------------------------------------------------- -# -# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. -# SPDX-License-Identifier: BSD-3-Clause -# -# ----------------------------------------------------------------------------- - -import torch -from transformers import AutoConfig, GptOssForCausalLM, TextStreamer - -from QEfficient import QEFFAutoModelForCausalLM -from QEfficient.utils._utils import load_hf_tokenizer -from QEfficient.utils.constants import Constants -from QEfficient.utils.run_utils import ApiRunner - -Constants.INPUT_STR = [ - "Make sure tokens don't repeat\n\nTo make a simple cup of coffee, start by boiling water. Add one to two teaspoons of instant coffee powder to a mug. Pour the hot water over the coffee and stir well. Add sugar and milk to taste, if desired. For brewed coffee, use a French press or drip filter. 
Add coarsely ground coffee to the device, pour hot water over it, and let it steep for four minutes. Press or filter the coffee, then serve" -] - -torch.manual_seed(42) -model_id = "openai/gpt-oss-20b" -config = AutoConfig.from_pretrained(model_id) -config.num_hidden_layers = 2 - -# Remove the quantization_config attribute if it exists, to avoid MXFP4 Issues -if hasattr(config, "quantization_config"): - delattr(config, "quantization_config") - -model = GptOssForCausalLM.from_pretrained( - "/home/vbaddi/transformers/src/transformers/models/gpt_oss/new_weights", - torch_dtype=torch.float32, - attn_implementation="eager", - config=config, -) -model.eval() -model.generation_config.sample = False -tokenizer = load_hf_tokenizer(pretrained_model_name_or_path=model_id) -config = model.config -batch_size = len(Constants.INPUT_STR) - -api_runner = ApiRunner(batch_size, tokenizer, config, Constants.INPUT_STR, 97, 256) -pytorch_hf_tokens = api_runner.run_hf_model_on_pytorch(model) - - -qeff_model = QEFFAutoModelForCausalLM(model, continuous_batching=False) -# pytorch_kv_tokens = api_runner.run_kv_model_on_pytorch(qeff_model.model) - -onnx_model_path = qeff_model.export() - - -qpc_path = qeff_model.compile( - prefill_seq_len=128, - ctx_len=256, - num_cores=16, - mxfp6_matmul=False, - mxint8_kv_cache=False, - num_devices=1, - mos=1, - aic_enable_depth_first=True, - num_speculative_tokens=None, -) -print(f"qpc path is {qpc_path}") -streamer = TextStreamer(tokenizer) -exec_info = qeff_model.generate( - tokenizer, - streamer=streamer, - prompts=Constants.INPUT_STR[0], - device_ids=[0], -) - -print(pytorch_hf_tokens) -print(exec_info) -assert (exec_info.generated_ids[0][0, :159] == pytorch_hf_tokens).all() From 15ebe3915f159f745b283118872a58c275d730b9 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Mon, 3 Nov 2025 11:52:29 +0000 Subject: [PATCH 10/37] Enable CB for GptOssModel Signed-off-by: Mamta Singh Signed-off-by: Onkar Chougule --- QEfficient/utils/generate_inputs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index cb2a68fa9..a55284c3b 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -174,6 +174,7 @@ def prepare_ort_inputs(self): inputs["past_key." + str(i)] = np.zeros((cache_shape), dtype=np.float32) inputs["past_value." 
+ str(i)] = np.zeros((cache_shape), dtype=np.float32) else: + sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]] for i in range(self.n_layer): if ( all(hasattr(self.config, attr) for attr in ["sliding_window", "layer_types"]) From 52f64b4520c704be2b0c3fe7484a623c41d0b101 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Tue, 4 Nov 2025 06:33:47 +0000 Subject: [PATCH 11/37] Fix tests Signed-off-by: Mamta Singh --- QEfficient/utils/generate_inputs.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index a55284c3b..b99f9ea9c 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -92,12 +92,15 @@ def prepare_pytorch_inputs(self): inputs["batch_index"] = torch.arange(self.full_batch_size).view(-1, 1) past_key_values = [] - sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]] for i in range(self.n_layer): +<<<<<<< HEAD if ( all(hasattr(self.config, attr) for attr in ["sliding_window", "layer_types"]) and self.config.layer_types[i] == "sliding_attention" ): +======= + if hasattr(self.config, "sliding_window") and self.config.layer_types[i] == "sliding_attention": +>>>>>>> b1ed627 (Fix tests) pad_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]] else: pad_shape = self.padding_shape @@ -174,7 +177,6 @@ def prepare_ort_inputs(self): inputs["past_key." + str(i)] = np.zeros((cache_shape), dtype=np.float32) inputs["past_value." + str(i)] = np.zeros((cache_shape), dtype=np.float32) else: - sliding_padding_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]] for i in range(self.n_layer): if ( all(hasattr(self.config, attr) for attr in ["sliding_window", "layer_types"]) From 79cbae908eb54b0c5e4b7da27bc891e5f32a2be2 Mon Sep 17 00:00:00 2001 From: Mamta Singh Date: Tue, 4 Nov 2025 09:41:57 +0000 Subject: [PATCH 12/37] Address review comments Signed-off-by: Mamta Singh --- QEfficient/transformers/models/modeling_auto.py | 5 ----- QEfficient/utils/generate_inputs.py | 4 ---- pyproject.toml | 4 ++-- 3 files changed, 2 insertions(+), 11 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index f4a59a2c5..8edc1f3f0 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3082,11 +3082,6 @@ def compile( for kv in ["key", "value"]: custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype - # HACK for now - if self.model.config.model_type == "gpt_oss": - for spec in specializations: - spec.update({"sliding_window": 128}) - qpc_path = self._compile( onnx_path=onnx_path, compile_dir=compile_dir, diff --git a/QEfficient/utils/generate_inputs.py b/QEfficient/utils/generate_inputs.py index b99f9ea9c..95474acfd 100644 --- a/QEfficient/utils/generate_inputs.py +++ b/QEfficient/utils/generate_inputs.py @@ -93,14 +93,10 @@ def prepare_pytorch_inputs(self): past_key_values = [] for i in range(self.n_layer): -<<<<<<< HEAD if ( all(hasattr(self.config, attr) for attr in ["sliding_window", "layer_types"]) and self.config.layer_types[i] == "sliding_attention" ): -======= - if hasattr(self.config, "sliding_window") and self.config.layer_types[i] == "sliding_attention": ->>>>>>> b1ed627 (Fix tests) pad_shape = self.padding_shape[:2] + [self.config.sliding_window] + [self.padding_shape[-1]] 
else: pad_shape = self.padding_shape diff --git a/pyproject.toml b/pyproject.toml index cc38f4bf8..8e179ab4a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,8 @@ dependencies = [ "transformers==4.55.0", "huggingface-hub==0.34.0", "hf_transfer==0.1.9", - "peft", - "datasets", + "peft==0.13.2", + "datasets==2.20.0", "fsspec==2023.6.0", "multidict==6.0.4", "urllib3<2", From 3e2a261459a25e47324dd77a4759c9c908486261 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 4 Nov 2025 19:28:03 +0000 Subject: [PATCH 13/37] prefill only changes for gpt-oss Signed-off-by: Onkar Chougule --- QEfficient/__init__.py | 23 +- QEfficient/base/modeling_qeff.py | 25 +- QEfficient/transformers/cache_utils.py | 31 +++ QEfficient/transformers/modeling_utils.py | 3 + .../models/gpt_oss/modeling_gpt_oss.py | 238 +++++++++++++++++- .../transformers/models/modeling_auto.py | 84 ++++--- .../transformers/models/pytorch_transforms.py | 11 + QEfficient/utils/_utils.py | 1 + QEfficient/utils/hash_utils.py | 1 + examples/gpt_oss_disagg_mode.py | 47 ++++ 10 files changed, 417 insertions(+), 47 deletions(-) create mode 100644 examples/gpt_oss_disagg_mode.py diff --git a/QEfficient/__init__.py b/QEfficient/__init__.py index 7f63b34ca..0360f2753 100644 --- a/QEfficient/__init__.py +++ b/QEfficient/__init__.py @@ -6,7 +6,17 @@ # ----------------------------------------------------------------------------- import os -import warnings + +# ----------------------------------------------------------------------------- # +# For faster downloads via hf_transfer +# This code is put above import statements as this needs to be executed before +# hf_transfer is imported (will happen on line 15 via leading imports) +os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +# DO NOT ADD ANY CODE ABOVE THIS LINE +# Please contact maintainers if you must edit this file above this line. 
+# ----------------------------------------------------------------------------- # +# Placeholder for all non-transformer models registered in QEfficient +import warnings # noqa: I001 import QEfficient.utils.model_registery # noqa: F401 from QEfficient.base import ( @@ -25,6 +35,10 @@ from QEfficient.utils import custom_format_warning from QEfficient.utils.logging_utils import logger +# custom warning for the better logging experience +warnings.formatwarning = custom_format_warning + + # Users can use QEfficient.export for exporting models to ONNX export = qualcomm_efficient_converter __all__ = [ @@ -40,14 +54,7 @@ "QEFFAutoModelForSpeechSeq2Seq", "QEFFCommonLoader", ] -# For faster downloads via hf_transfer -# This code is put above import statements as this needs to be executed before -# hf_transfer is imported (will happen on line 15 via leading imports) -os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" -# Placeholder for all non-transformer models registered in QEfficient -# custom warning for the better logging experience -warnings.formatwarning = custom_format_warning # Conditionally import QAIC-related modules if the SDK is installed __version__ = "0.0.1.dev0" diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index ef7e83adf..a0b5d1cdb 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -66,6 +66,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None: super().__init__() self.model = model self.hash_params = create_model_params(self, **kwargs) + self.prefill_onnx_path: Optional[str] = None self.onnx_path: Optional[str] = None self.qpc_path: Optional[str] = None self.qpc_session: Optional[QAICInferenceSession] = None @@ -189,6 +190,7 @@ def _export( export_dir: Optional[str] = None, offload_pt_weights: bool = True, use_onnx_subfunctions: bool = False, + prefill_only: Optional[bool] = False, ) -> str: """ Export the PyTorch model to ONNX and apply ONNX transforms @@ -217,7 +219,10 @@ def _export( # Return early if ONNX already exists if onnx_path.is_file(): - self.onnx_path = onnx_path + if prefill_only: + self.prefill_onnx_path = onnx_path + else: + self.onnx_path = onnx_path return onnx_path # check if the model is in meta state or weights are offloaded @@ -315,7 +320,10 @@ def _export( self._onnx_transforms.remove(CustomOpTransform) self._onnx_transforms.remove(RenameFunctionOutputsTransform) - self.onnx_path = onnx_path + if prefill_only: + self.prefill_onnx_path = onnx_path + else: + self.onnx_path = onnx_path return onnx_path @dump_qconfig @@ -332,6 +340,8 @@ def _compile( enable_qnn: Optional[bool] = False, qnn_config: Optional[str] = None, use_onnx_subfunctions: bool = False, + prefill_only: Optional[str] = None, + offload_pt_weights: Optional[bool] = True, **compiler_options, ) -> str: """ @@ -357,11 +367,17 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored. 
""" + kwargs = {"offload_pt_weights": offload_pt_weights, + "use_onnx_subfunctions": use_onnx_subfunctions} + if prefill_only and self.prefill_onnx_path is None: + kwargs.update({"prefill_only": prefill_only, "prefill_seq_len": specializations[0].get("seq_len")}) + self.export(**kwargs) + onnx_path = Path(onnx_path or self.prefill_onnx_path) if onnx_path is None and self.onnx_path is None: - self.export(use_onnx_subfunctions=use_onnx_subfunctions) + self.export(**kwargs) + onnx_path = Path(onnx_path or self.onnx_path) - onnx_path = Path(onnx_path or self.onnx_path) compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" if not onnx_path.is_file(): @@ -423,6 +439,7 @@ def _compile( "mdp_ts_num_devices": mdp_ts_num_devices, "mdp_ts_json": mdp_ts_json, "num_speculative_tokens": num_speculative_tokens, + "prefill_only": prefill_only, } compile_hash = hash_dict_params(compile_hash_params) diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 62cc71a4c..90cbdb2dd 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -681,6 +681,37 @@ def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]: legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),) return legacy_cache + def write_only( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + if len(self.key_cache) <= layer_idx: + self.key_cache.append(key_states) + self.value_cache.append(value_states) + k_out, v_out = key_states, value_states + else: + position_ids = cache_kwargs.get("position_ids") + is_sliding_layer = cache_kwargs.get("is_sliding") + _, _, ctx_len, _ = self.key_cache[layer_idx].shape + if is_sliding_layer: + kv_position_ids = torch.arange(ctx_len, dtype=torch.int64).reshape(1, -1) + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + else: + kv_position_ids = position_ids + + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], kv_position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply( + self.value_cache[layer_idx], kv_position_ids, value_states + ) + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + return k_out, v_out + def update( self, key_states: torch.Tensor, diff --git a/QEfficient/transformers/modeling_utils.py b/QEfficient/transformers/modeling_utils.py index 5337b44f5..47059d8dc 100644 --- a/QEfficient/transformers/modeling_utils.py +++ b/QEfficient/transformers/modeling_utils.py @@ -188,6 +188,9 @@ # This is for supporting different seq_len for different layers for Sliding window attn, chunked attn etc. DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH = {"gemma3", "llama4", "gemma3_text", "llama4_text"} +# This is for supporting different modelling classes specially written for prefill-only model +SPECIALIZED_PREFILL_ONLY_MODEL_ARCH = {"gpt_oss"} + # Define a transformers layers to QEff layers dictionary # While onboarding new models make sure to add the new layer maps to this dictionary. 
TransformersToQEffModulesDict: Dict[Type[nn.Module], Type[nn.Module]] = { diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 84552aff4..2f88613bc 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- +import os from typing import Callable, Optional, Union import torch @@ -32,6 +33,7 @@ from QEfficient.transformers.modeling_attn_mask_utils import _create_causal_mask from QEfficient.utils import constants from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE +from QEfficient.utils.logging_utils import logger class QEffGptOssExperts(GptOssExperts): @@ -42,8 +44,8 @@ def __qeff_init__(self): self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) -class QEffGptOssMLP(GptOssMLP): - def alt_forward(self, hidden: torch.Tensor): +class QEffPrefillOnlyGptOssMLP(GptOssMLP): + def forward(self, hidden: torch.Tensor): B, S, H = hidden.shape T = B * S hidden = hidden.view(T, H) @@ -95,6 +97,8 @@ def alt_forward(self, hidden: torch.Tensor): # original shape [B, S, H] return expert_out.view(B, S, H), router_logits + +class QEffGptOssMLP(GptOssMLP): # ------------------- Gather based, weights as activation approach --------------- def forward_weights_as_activation(self, hidden_states): bs, seq_len, _ = hidden_states.shape @@ -404,6 +408,137 @@ def eager_attention_forward( return attn_output, attn_weights +def eager_attention_forward_blocked( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + softmax_count = 0 + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + BS, NH, CL, DH = query.shape + target_blocks = int(os.environ.get("NUM_BLOCKS")) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (CL // target_blocks)) + + print(f"CL={CL}, target_blocks={target_blocks}") + + block_count = 0 + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + + # Calculate block size (last block should be handled with remainder) + if block_idx == target_blocks - 1: + real_q_len = CL - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + q_block = query[:, :, qi : qi + real_q_len, :] + scores = torch.matmul(q_block, key_states.transpose(2, 3)) * scaling + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :] + curr_attn_weights = torch.where( + attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores + ) + sinks = module.sinks.reshape(1, -1, 1, 1).expand( + curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1 + ) + combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32) + curr_attn_weights = curr_attn_weights[..., :-1] + out_block = torch.matmul(curr_attn_weights, value_states) + outs.append(out_block) + output = torch.cat(outs, dim=2) + + print(f"Completed {block_count} blocks, {softmax_count} softmax operations") + output = output.view(BS, NH, CL, DH).transpose(1, 
2).contiguous() + return output, output + + +class QEffPrefillOnlyGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + hidden_shape = (*input_shape, -1, self.head_dim) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": past_key_value.sliding_window_len, + } + if self.sliding_window is not None: + sliding_window_len = past_key_value.sliding_window_len + short_read_idx = torch.arange(sliding_window_len) + read_idx = short_read_idx + torch.where( + position_ids.max() > sliding_window_len - 1, position_ids.max() - sliding_window_len + 1, 0 + ) + # This is a trick to export with NUM_BLOCKS position_ids.max(), 0, read_idx) + k_cache = key_states[:, :, read_idx, :] + v_cache = value_states[:, :, read_idx, :] + else: + k_cache, v_cache = key_states, value_states + _, _ = past_key_value.write_only(k_cache, v_cache, self.layer_idx, cache_kwargs) + + if self.sliding_window is not None: + attention_mask = sliding_mask + else: + attention_mask = attention_mask + + attention_interface: Callable = eager_attention_forward_blocked + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + class QEffGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -511,7 +646,6 @@ def forward( residual = hidden_states hidden_states = self.post_attention_layernorm(hidden_states) hidden_states, _ = self.mlp(hidden_states) # diff with llama: router scores - # alth, _ = self.mlp.alt_forward(hidden_states) hidden_states = hidden_states.reshape(residual.shape) hidden_states = residual + hidden_states outputs = (hidden_states,) @@ -525,6 +659,98 @@ def forward( return outputs +class QEffPrefillOnlyGptOssModel(GptOssModel): + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: 
Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + return_dict: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + **kwargs: Unpack[TransformersKwargs], + ) -> MoeModelOutputWithPast: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if (input_ids is None) ^ (inputs_embeds is not None): + raise ValueError("You must specify exactly one of input_ids or inputs_embeds") + + return_legacy_cache = False + if use_cache and not isinstance(past_key_values, Cache): + return_legacy_cache = True + past_key_values = QEffHybridCacheForGPTOSS.from_legacy_cache(self.config, past_key_values) + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + if cache_position is None: + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + cache_position = torch.arange( + past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device + ) + if position_ids is None: + position_ids = cache_position.unsqueeze(0) + + # target_length = attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else past_seen_tokens + causal_mask = _create_causal_mask(position_ids=position_ids, target_length=past_key_values.max_cache_len) + sliding_mask = _create_causal_mask( + position_ids=position_ids, + target_length=past_key_values.max_cache_len, + sliding_window=past_key_values.sliding_window_len, + ) + + hidden_states = inputs_embeds + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + + for decoder_layer in self.layers: + if output_hidden_states: + all_hidden_states += (hidden_states,) + + layer_outputs = decoder_layer( + hidden_states, + attention_mask=causal_mask, + position_ids=position_ids, + past_key_value=past_key_values, + batch_index=batch_index, + use_cache=use_cache, + output_attentions=output_attentions, + cache_position=cache_position, + sliding_mask=sliding_mask, + **kwargs, + ) + hidden_states = layer_outputs[0] + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if return_legacy_cache: + past_key_values = past_key_values.to_legacy_cache() + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=past_key_values if use_cache else None, + ) + + class QEffGptOssModel(GptOssModel): def forward( self, @@ -724,9 +950,15 @@ def get_specializations( batch_size: int, prefill_seq_len: int, ctx_len: int, + **kwargs, ): batch_size = batch_size if batch_size else 1 prefill_seq_len = prefill_seq_len if prefill_seq_len else constants.PROMPT_LEN + if kwargs.get("prefill_only") and ctx_len != prefill_seq_len: + ctx_len = prefill_seq_len + logger.warning( + f"overriding 
ctx_len={prefill_seq_len}, currently we don't support ctx_len different than prefill_seq_len for prefill_only model" + ) ctx_len = ctx_len if ctx_len else constants.CTX_LEN specializations = [ diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 8edc1f3f0..e56486058 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -5,6 +5,7 @@ # # ---------------------------------------------------------------------------- +import os import warnings from pathlib import Path from time import perf_counter @@ -37,13 +38,17 @@ get_compilation_dims, ) from QEfficient.generation.vlm_generation import VisionLanguageGeneration -from QEfficient.transformers.modeling_utils import DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH +from QEfficient.transformers.modeling_utils import ( + DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH, + SPECIALIZED_PREFILL_ONLY_MODEL_ARCH, +) from QEfficient.transformers.models.pytorch_transforms import ( BlockedKVAttentionTransform, CustomOpsTransform, KVCacheExternalModuleMapperTransform, KVCacheTransform, PoolingTransform, + PrefillOnlyTransform, SamplerTransform, SpDTransform, VlmKVOffloadTransform, @@ -316,7 +321,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -353,7 +358,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -603,15 +608,7 @@ def __init__(self, model: nn.modules, **kwargs): self.model = model.get_qeff_vision_encoder() self.hash_params["qeff_auto_class"] = self.__class__.__name__ - def export( - self, - inputs, - output_names, - dynamic_axes, - export_dir=None, - offload_pt_weights=True, - use_onnx_subfunctions: bool = False, - ): + def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs): """ Exports the vision encoder component to ONNX format. @@ -641,7 +638,7 @@ def export( dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -771,15 +768,7 @@ def __init__(self, model, qaic_config, **kwargs): if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None: BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks")) - def export( - self, - inputs, - output_names, - dynamic_axes, - export_dir=None, - offload_pt_weights=True, - use_onnx_subfunctions: bool = False, - ): + def export(self, inputs, output_names, dynamic_axes, export_dir=None, offload_pt_weights=True, **kwargs): """ Exports the language decoder component to ONNX format. 
@@ -809,7 +798,7 @@ def export( dynamic_axes, export_dir=export_dir, offload_pt_weights=offload_pt_weights, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -2555,7 +2544,14 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False, **kwargs) -> str: + def export( + self, + export_dir: Optional[str] = None, + prefill_only: Optional[bool] = False, + prefill_seq_len: Optional[int] = None, + offload_pt_weights: Optional[bool] = True, + **kwargs, + ) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -2575,8 +2571,25 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = str Path to the generated ONNX graph file. """ + if prefill_only: + block_size = os.environ.get("BLOCK_SIZE", None) + if block_size is None: + block_size = 128 + logger.warning( + "Setting BLOCK_SIZE=128 for prefill_only model, please set ENV variable `BLOCK_SIZE` to override" + ) + if prefill_seq_len is None or prefill_seq_len % block_size != 0: + raise ValueError( + f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. " + f"Received: prefill_seq_len={prefill_seq_len}" + ) + + os.environ["NUM_BLOCKS"] = str(prefill_seq_len // block_size) + if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH: + self.model, tf = PrefillOnlyTransform.apply(self.model) + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN if not prefill_only else prefill_seq_len // block_size fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS kv_cache_shape = get_padding_shape_from_config( self.model.config, fbs if self.continuous_batching else bs, seq_len @@ -2659,14 +2672,14 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names=output_names, dynamic_axes=dynamic_axes, ) - return self._export( example_inputs, output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), offload_pt_weights=kwargs.get("offload_pt_weights", True), + prefill_only=prefill_only, ) def get_sampling_inputs_and_outputs( @@ -2756,6 +2769,7 @@ def build_prefill_specialization( batch_size: int = 1, kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, + **kwargs, ): """ Builds a dictionary representing a compilation specialization for the prefill phase. 
@@ -2783,6 +2797,7 @@ def build_prefill_specialization( batch_size=1 if self.continuous_batching else batch_size, prefill_seq_len=prefill_seq_len, ctx_len=ctx_len, + **kwargs, )[0] else: spec = { @@ -2880,6 +2895,7 @@ def compile( num_speculative_tokens: Optional[int] = None, prefill_only: Optional[bool] = None, use_onnx_subfunctions: bool = False, + offload_pt_weights: Optional[bool] = True, **compiler_options, ) -> str: """ @@ -3014,6 +3030,9 @@ def compile( ): raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.") + if kv_cache_batch_size and prefill_only is not None and prefill_only: + logger.warning("kv_cache_batch_size will be ignored as prefill_only is set to True") + # Infer kv_cache_batch_size if not provided kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size @@ -3081,7 +3100,6 @@ def compile( for i in range(self.num_layers): for kv in ["key", "value"]: custom_io[f"past_{kv}.{i}{suffix}"] = kv_cache_dtype - qpc_path = self._compile( onnx_path=onnx_path, compile_dir=compile_dir, @@ -3096,6 +3114,8 @@ def compile( aic_num_cores=num_cores, mxint8_kv_cache=mxint8_kv_cache, use_onnx_subfunctions=use_onnx_subfunctions, + prefill_only=prefill_only, + offload_pt_weights=offload_pt_weights, **compiler_options, ) @@ -3287,7 +3307,7 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model to ONNX format using ``torch.onnx.export``. @@ -3315,7 +3335,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( @@ -3663,7 +3683,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, pooling=None, *args, **k def get_model_config(self) -> dict: return self.model.config.__dict__ - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Exports the model to ``ONNX`` format using ``torch.onnx.export``. 
@@ -3691,7 +3711,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + use_onnx_subfunctions=kwargs.get("use_onnx_subfunctions", False), ) def compile( diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 21a867eb5..1d7324786 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -261,6 +261,9 @@ QEffGptOssForCausalLM, QEffGptOssMLP, QEffGptOssModel, + QEffPrefillOnlyGptOssAttention, + QEffPrefillOnlyGptOssMLP, + QEffPrefillOnlyGptOssModel, ) from QEfficient.transformers.models.gptj.modeling_gptj import ( QEffGPTJAttention, @@ -634,6 +637,14 @@ def apply(cls, model: nn.Module) -> Tuple[nn.Module, bool]: return model, transformed +class PrefillOnlyTransform(ModuleMappingTransform): + _module_mapping = { + QEffGptOssModel: QEffPrefillOnlyGptOssModel, + QEffGptOssAttention: QEffPrefillOnlyGptOssAttention, + QEffGptOssExperts: QEffPrefillOnlyGptOssMLP, + } + + class SpDTransform: """ Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index 131a7fc26..ea22d932f 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -567,6 +567,7 @@ def wrapper(self, *args, **kwargs): export_kwargs=all_args.get("export_kwargs", None), onnx_transform_kwargs=all_args.get("onnx_transform_kwargs", None), use_onnx_subfunctions=all_args.get("use_onnx_subfunctions", False), + prefill_only=all_args.get("prefill_only", False), ) export_dir = export_dir.with_name(export_dir.name + "-" + export_hash) diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py index 948b72e6a..2c93e3990 100644 --- a/QEfficient/utils/hash_utils.py +++ b/QEfficient/utils/hash_utils.py @@ -68,5 +68,6 @@ def create_export_hash(**kwargs): export_hash_params.update(onnx_transform_kwargs) if export_hash_params.get("peft_config") is not None and not isinstance(export_hash_params["peft_config"], dict): export_hash_params["peft_config"] = export_hash_params["peft_config"].to_dict() + export_hash_params["prefill_only"] = kwargs.get("prefill_only") return hash_dict_params(export_hash_params), export_hash_params diff --git a/examples/gpt_oss_disagg_mode.py b/examples/gpt_oss_disagg_mode.py new file mode 100644 index 000000000..22238de13 --- /dev/null +++ b/examples/gpt_oss_disagg_mode.py @@ -0,0 +1,47 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +from transformers import AutoTokenizer, TextStreamer + +from QEfficient import QEFFAutoModelForCausalLM + +model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) +tokenizer = AutoTokenizer.from_pretrained(model_id) + +decode_qpc_path = qeff_model.compile( + prefill_seq_len=1, # Currently we can get best perf using PL=1 i.e. decode-only model, prefill optimizations are being worked on. 
+ ctx_len=256, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, +) +prefill_qpc_path = qeff_model.compile( + prefill_seq_len=256, # Currently we can get best perf using PL=1 i.e. decode-only model, prefill optimizations are being worked on. + ctx_len=256, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, +) +# print(f"qpc path is {qpc_path}") +# streamer = TextStreamer(tokenizer) +# exec_info = qeff_model.generate( +# tokenizer, +# prompts="Who is your creator? and What all you are allowed to do?", +# device_id=[0, 1, 2, 3], +# ) From 0e3a6739c45fd05c9f62c69b4d1e3266378d6e7d Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 5 Nov 2025 06:24:20 +0000 Subject: [PATCH 14/37] fixed mapping Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/pytorch_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 1d7324786..5a828b04c 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -641,7 +641,7 @@ class PrefillOnlyTransform(ModuleMappingTransform): _module_mapping = { QEffGptOssModel: QEffPrefillOnlyGptOssModel, QEffGptOssAttention: QEffPrefillOnlyGptOssAttention, - QEffGptOssExperts: QEffPrefillOnlyGptOssMLP, + QEffGptOssMLP: QEffPrefillOnlyGptOssMLP, } From 3ce4320243a96861d981d16f46288f9f7b24acd2 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 6 Nov 2025 07:58:55 +0000 Subject: [PATCH 15/37] added test Signed-off-by: Onkar Chougule --- QEfficient/base/modeling_qeff.py | 1 + .../models/gpt_oss/modeling_gpt_oss.py | 2 +- .../transformers/models/modeling_auto.py | 14 +- .../transformers/models/pytorch_transforms.py | 4 + .../transformers/quantizers/__init__.py | 4 +- examples/gpt_oss_disagg_mode.py | 133 ++++++++++++-- tests/transformers/models/test_disagg_mode.py | 171 ++++++++++++++++++ 7 files changed, 311 insertions(+), 18 deletions(-) create mode 100644 tests/transformers/models/test_disagg_mode.py diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index a0b5d1cdb..b7078b7c9 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -66,6 +66,7 @@ def __init__(self, model: torch.nn.Module, **kwargs) -> None: super().__init__() self.model = model self.hash_params = create_model_params(self, **kwargs) + self.prefill_enabled = False self.prefill_onnx_path: Optional[str] = None self.onnx_path: Optional[str] = None self.qpc_path: Optional[str] = None diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 2f88613bc..c439e7bbd 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -422,7 +422,7 @@ def eager_attention_forward_blocked( value_states = repeat_kv(value, module.num_key_value_groups) BS, NH, CL, DH = query.shape - target_blocks = int(os.environ.get("NUM_BLOCKS")) + target_blocks = int(os.environ.get("NUM_BLOCKS", 1)) block_positions = [] for j in range(target_blocks): block_positions.append(j * (CL // target_blocks)) diff --git a/QEfficient/transformers/models/modeling_auto.py 
b/QEfficient/transformers/models/modeling_auto.py index e56486058..84191a3ea 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -49,6 +49,7 @@ KVCacheTransform, PoolingTransform, PrefillOnlyTransform, + RevertPrefillOnlyTransform, SamplerTransform, SpDTransform, VlmKVOffloadTransform, @@ -2348,6 +2349,14 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] + def prefill(self, enable: Optional[bool] = True): + if enable: + self.model, tf = PrefillOnlyTransform.apply(self.model) + self.prefill_enabled = True + else: + self.model, tf = RevertPrefillOnlyTransform.apply(self.model) + self.prefill_enabled = False + def __init__( self, model: nn.Module, @@ -2586,8 +2595,9 @@ def export( os.environ["NUM_BLOCKS"] = str(prefill_seq_len // block_size) if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH: - self.model, tf = PrefillOnlyTransform.apply(self.model) - + self.prefill(True) + else: + self.prefill(False) bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN if not prefill_only else prefill_seq_len // block_size fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 5a828b04c..3bdd21d54 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -645,6 +645,10 @@ class PrefillOnlyTransform(ModuleMappingTransform): } +class RevertPrefillOnlyTransform(ModuleMappingTransform): + _module_mapping = {v: k for k, v in PrefillOnlyTransform._module_mapping.items()} + + class SpDTransform: """ Apply generic QEffForCausalLM forward pass to extract `num_speculative_tokens+1` hidden states before computing logits during decode phase and extract last predicted token during prefill. diff --git a/QEfficient/transformers/quantizers/__init__.py b/QEfficient/transformers/quantizers/__init__.py index dfadc00ef..dc2308e99 100644 --- a/QEfficient/transformers/quantizers/__init__.py +++ b/QEfficient/transformers/quantizers/__init__.py @@ -5,6 +5,6 @@ # # ----------------------------------------------------------------------------- -from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers +from QEfficient.transformers.quantizers.auto import replace_transformers_quantizers, undo_transformers_quantizers -__all__ = ["replace_transformers_quantizers"] +__all__ = ["replace_transformers_quantizers", "undo_transformers_quantizers"] diff --git a/examples/gpt_oss_disagg_mode.py b/examples/gpt_oss_disagg_mode.py index 22238de13..9f196e002 100644 --- a/examples/gpt_oss_disagg_mode.py +++ b/examples/gpt_oss_disagg_mode.py @@ -5,18 +5,76 @@ # # ----------------------------------------------------------------------------- -from transformers import AutoTokenizer, TextStreamer +import time + +import numpy as np +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 +# prompt = """ +# Billions of years ago, in the vast emptiness of the early universe, tiny fluctuations in the density of matter began to grow under the influence of gravity. 
Clouds of gas—mostly hydrogen and helium—started to collapse, forming the first stars. These stars grouped together, bound by gravity, creating the earliest galaxies. +# Over time, these galaxies merged, collided, and evolved, shaping their spiral arms, elliptical forms, or irregular structures. Within their swirling depths, stars were born and died, enriching the galactic gas with heavier elements. These elements became the building blocks for planets, moons, and eventually life. +# Life is a very interesting phenomenon that occured in this universe +# """ +# prompt = "Once upon a time" +prompt = """ +Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. + +As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown. -qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) +The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. +""" +all_outputs = [] +# Run prefill tokenizer = AutoTokenizer.from_pretrained(model_id) +PREFILL_SEQ_LEN = 256 +CTX_LEN = 256 +inputs = tokenizer(prompt, return_tensors="np", padding=True) +position_ids = inputs["attention_mask"].sum(1, keepdims=True) +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + +# Initialize variables specific to request +# Calculate the max generation length. 
+max_gen_len = CTX_LEN - position_ids.max() +generation_len = max_gen_len + +# model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) +# config = model.config +# inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +# inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +# inputs.pop("token_type_ids", None) +# inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()} +# cache = HybridCache(config=config, batch_size=1, max_cache_len=8192) +# out = model(**tokenizer(prompt, return_tensors="pt"), past_key_values=cache) + + +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) +config = qeff_model.model.config +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +past_key_values = [] +for i in range(config.num_hidden_layers): + cache_len = config.sliding_window if i % 2 == 0 else PREFILL_SEQ_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) +inputs["past_key_values"] = past_key_values + +# qeff_out = qeff_model.model(**inputs) decode_qpc_path = qeff_model.compile( - prefill_seq_len=1, # Currently we can get best perf using PL=1 i.e. decode-only model, prefill optimizations are being worked on. - ctx_len=256, + prefill_seq_len=1, + ctx_len=CTX_LEN, num_cores=16, mxfp6_matmul=True, mxint8_kv_cache=True, @@ -27,8 +85,8 @@ offload_pt_weights=False, ) prefill_qpc_path = qeff_model.compile( - prefill_seq_len=256, # Currently we can get best perf using PL=1 i.e. decode-only model, prefill optimizations are being worked on. - ctx_len=256, + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, num_cores=16, mxfp6_matmul=True, mxint8_kv_cache=True, @@ -38,10 +96,59 @@ num_speculative_tokens=None, prefill_only=True, ) -# print(f"qpc path is {qpc_path}") -# streamer = TextStreamer(tokenizer) -# exec_info = qeff_model.generate( -# tokenizer, -# prompts="Who is your creator? 
and What all you are allowed to do?", -# device_id=[0, 1, 2, 3], -# ) + +prefill_session = QAICInferenceSession(prefill_qpc_path) + +logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) +prefill_session.set_buffers({"logits": logits_out_placeholder}) +inputs.pop("past_key_values") +inputs = {k: v.detach().numpy() for k, v in inputs.items()} +st = time.time() +qpc_out = prefill_session.run(inputs) +print(f"time for prefill_run={time.time() - st} sec\n") + +decode_session = QAICInferenceSession(decode_qpc_path) +decode_session.set_buffers({"logits": logits_out_placeholder}) + +decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, +} +print("pos_id for decodee", decode_inputs["position_ids"]) + +all_outputs.append(decode_inputs["input_ids"][0][0]) +for i in range(config.num_hidden_layers): + if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window: + k = qpc_out[f"past_key.{i}_RetainedState"] + v = qpc_out[f"past_value.{i}_RetainedState"] + mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window + decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2) + decode_inputs[f"past_value.{i}"] = np.concatenate((v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2) + else: + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +st = time.time() +decode_out = decode_session.run(decode_inputs) +print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") +decode_session.skip_buffers( + [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")] +) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +st = time.time() +for i in range(generation_len - 2): + loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + # for i in range(config.num_hidden_layers): + # loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + # loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + all_outputs.append(loop_decode_inputs["input_ids"][0][0]) + decode_out = decode_session.run(loop_decode_inputs) + pos_id += 1 + + +print(f"time for decode generation = {(time.time() - st) / (generation_len - 2)}") +print(all_outputs) +print(tokenizer.decode(all_outputs)) diff --git a/tests/transformers/models/test_disagg_mode.py b/tests/transformers/models/test_disagg_mode.py new file mode 100644 index 000000000..07106bddc --- /dev/null +++ b/tests/transformers/models/test_disagg_mode.py @@ -0,0 +1,171 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import time + +import numpy as np +import pytest +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession +from QEfficient.transformers.quantizers import replace_transformers_quantizers, undo_transformers_quantizers + +replace_transformers_quantizers() + +model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 +# prompt = """ +# Billions of years ago, in the vast emptiness of the early universe, tiny fluctuations in the density of matter began to grow under the influence of gravity. Clouds of gas—mostly hydrogen and helium—started to collapse, forming the first stars. These stars grouped together, bound by gravity, creating the earliest galaxies. +# Over time, these galaxies merged, collided, and evolved, shaping their spiral arms, elliptical forms, or irregular structures. Within their swirling depths, stars were born and died, enriching the galactic gas with heavier elements. These elements became the building blocks for planets, moons, and eventually life. +# Thus, from the quiet whispers of cosmic dust, a galaxy emerged—an island of stars, nebulae, and mysteries, drifting through the infinite sea of space. +# As the galaxy matured, its stars danced in intricate orbits, weaving patterns shaped by gravity and time. Supernovae exploded like cosmic fireworks, scattering elements across space and triggering new waves of star formation. Black holes formed at the hearts of galaxies, anchoring their structure and influencing their evolution. Over billions of years, the galaxy became a dynamic ecosystem—where stars are born, live, and die—each cycle adding to the richness of the cosmic tapestry. +# """ +prompt = """ +Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. + +As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown. + +The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. +""" + + +@pytest.mark.parametrize("model_id", [model_id]) +def test_disagg_mode(model_id): + all_outputs = [] + # Run prefill + tokenizer = AutoTokenizer.from_pretrained(model_id) + PREFILL_SEQ_LEN = 256 + CTX_LEN = 256 + inputs = tokenizer(prompt, return_tensors="np", padding=True) + position_ids = inputs["attention_mask"].sum(1, keepdims=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float + padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + + # Initialize variables specific to request + # Calculate the max generation length. 
+ max_gen_len = CTX_LEN - position_ids.max() + generation_len = 50 + + # model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + model = AutoModelForCausalLM.from_pretrained(model_id) + config = model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()} + cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN) + ins = tokenizer(prompt, return_tensors="pt") + out = model(**ins, past_key_values=cache) + puts = { + "input_ids": out.logits[:, -1, :].argmax().reshape(1, -1), + "position_ids": ins["input_ids"].shape[-1].reshape(1, -1), + } + import ipdb + + ipdb.set_trace() + new_out = model(**puts, past_key_values=cache) + model.generation_config.do_sample = False + orig_all_out = model.generate( + **tokenizer(prompt, return_tensors="pt"), past_key_values=cache, max_new_tokens=max_gen_len + ) + undo_transformers_quantizers() + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) + # qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + qeff_model.prefill(True) + config = qeff_model.model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} + past_key_values = [] + for i in range(config.num_hidden_layers): + cache_len = 128 if i % 2 == 0 else PREFILL_SEQ_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) + inputs["past_key_values"] = past_key_values + + qeff_out = qeff_model.model(**inputs) + + import ipdb + + ipdb.set_trace() + + decode_qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, + ) + prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + ) + + prefill_session = QAICInferenceSession(prefill_qpc_path) + + logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) + prefill_session.set_buffers({"logits": logits_out_placeholder}) + inputs.pop("past_key_values") + inputs = {k: v.detach().numpy() for k, v in inputs.items()} + st = time.time() + qpc_out = prefill_session.run(inputs) + print(f"time for prefill_run={time.time() - st} sec\n") + import ipdb + + ipdb.set_trace() + decode_session = QAICInferenceSession(decode_qpc_path) + decode_session.set_buffers({"logits": logits_out_placeholder}) + decode_session.skip_buffers( + [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")] + ) + + decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, + } + + all_outputs.append(decode_inputs["input_ids"][0][0]) + for i in 
range(config.num_hidden_layers): + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + + st = time.time() + decode_out = decode_session.run(decode_inputs) + print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") + + st = time.time() + for i in range(generation_len - 2): + loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, + } + all_outputs.append(loop_decode_inputs["input_ids"][0][0]) + decode_out = decode_session.run(loop_decode_inputs) + + print(f"time for decode generation = {(time.time() - st) / (generation_len - 2)}") + print(all_outputs) + print(tokenizer.decode(all_outputs)) From e9296161844dbf555413eda4629321343ef6940c Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 6 Nov 2025 10:44:31 +0000 Subject: [PATCH 16/37] added test Signed-off-by: Onkar Chougule --- .../models/gpt_oss/modeling_gpt_oss.py | 6 +- tests/transformers/models/test_disagg_mode.py | 104 ++++-------------- 2 files changed, 23 insertions(+), 87 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index c439e7bbd..8545a859f 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -487,7 +487,8 @@ def forward( hidden_shape = (*input_shape, -1, self.head_dim) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + + cos, sin = self.rotary_emb(value_states, seq_len=getattr(self.config, "max_position_embeddings", 32 * 1024)) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -565,7 +566,7 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, seq_len=32 * 1024) + cos, sin = self.rotary_emb(value_states, seq_len=getattr(self.config, "max_position_embeddings", 32 * 1024)) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -804,7 +805,6 @@ def forward( ) hidden_states = inputs_embeds - # position_embeddings = self.rotary_emb(hidden_states, position_ids) # decoder layers all_hidden_states = () if output_hidden_states else None diff --git a/tests/transformers/models/test_disagg_mode.py b/tests/transformers/models/test_disagg_mode.py index 07106bddc..fdbd374ff 100644 --- a/tests/transformers/models/test_disagg_mode.py +++ b/tests/transformers/models/test_disagg_mode.py @@ -19,41 +19,33 @@ replace_transformers_quantizers() model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 -# prompt = """ -# Billions of years ago, in the vast emptiness of the early universe, tiny fluctuations in the density of matter began to grow under the influence of gravity. Clouds of gas—mostly hydrogen and helium—started to collapse, forming the first stars. These stars grouped together, bound by gravity, creating the earliest galaxies. 
-# Over time, these galaxies merged, collided, and evolved, shaping their spiral arms, elliptical forms, or irregular structures. Within their swirling depths, stars were born and died, enriching the galactic gas with heavier elements. These elements became the building blocks for planets, moons, and eventually life. -# Thus, from the quiet whispers of cosmic dust, a galaxy emerged—an island of stars, nebulae, and mysteries, drifting through the infinite sea of space. -# As the galaxy matured, its stars danced in intricate orbits, weaving patterns shaped by gravity and time. Supernovae exploded like cosmic fireworks, scattering elements across space and triggering new waves of star formation. Black holes formed at the hearts of galaxies, anchoring their structure and influencing their evolution. Over billions of years, the galaxy became a dynamic ecosystem—where stars are born, live, and die—each cycle adding to the richness of the cosmic tapestry. -# """ -prompt = """ + +prompt2 = """ Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown. The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. """ +prompt1 = "Once upon a time" + +prompts = [prompt1, prompt2] +@pytest.mark.on_qaic @pytest.mark.parametrize("model_id", [model_id]) -def test_disagg_mode(model_id): - all_outputs = [] +@pytest.mark.parametrize("prompt", prompts) +def test_disagg_mode_prefill(model_id, prompt): # Run prefill tokenizer = AutoTokenizer.from_pretrained(model_id) PREFILL_SEQ_LEN = 256 CTX_LEN = 256 inputs = tokenizer(prompt, return_tensors="np", padding=True) - position_ids = inputs["attention_mask"].sum(1, keepdims=True) padded_len = inputs["input_ids"].shape[1] num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len - # Initialize variables specific to request - # Calculate the max generation length. 
- max_gen_len = CTX_LEN - position_ids.max() - generation_len = 50 - - # model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) - model = AutoModelForCausalLM.from_pretrained(model_id) + model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) config = model.config inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) @@ -62,22 +54,12 @@ def test_disagg_mode(model_id): cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN) ins = tokenizer(prompt, return_tensors="pt") out = model(**ins, past_key_values=cache) - puts = { - "input_ids": out.logits[:, -1, :].argmax().reshape(1, -1), - "position_ids": ins["input_ids"].shape[-1].reshape(1, -1), - } - import ipdb - - ipdb.set_trace() - new_out = model(**puts, past_key_values=cache) - model.generation_config.do_sample = False - orig_all_out = model.generate( - **tokenizer(prompt, return_tensors="pt"), past_key_values=cache, max_new_tokens=max_gen_len - ) + undo_transformers_quantizers() - qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) - # qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + qeff_model = QEFFAutoModelForCausalLM.from_pretrained( + model_id, num_hidden_layers=2, max_position_embeddings=64 * 1024 + ) qeff_model.prefill(True) config = qeff_model.model.config inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) @@ -96,28 +78,15 @@ def test_disagg_mode(model_id): qeff_out = qeff_model.model(**inputs) - import ipdb + # Check our pytorch implementation + assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 - ipdb.set_trace() - - decode_qpc_path = qeff_model.compile( - prefill_seq_len=1, - ctx_len=CTX_LEN, - num_cores=16, - mxfp6_matmul=True, - mxint8_kv_cache=True, - num_devices=1, - mos=1, - aic_enable_depth_first=True, - num_speculative_tokens=None, - offload_pt_weights=False, - ) prefill_qpc_path = qeff_model.compile( prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, num_cores=16, - mxfp6_matmul=True, - mxint8_kv_cache=True, + mxfp6_matmul=False, + mxint8_kv_cache=False, num_devices=1, mos=1, aic_enable_depth_first=True, @@ -126,7 +95,6 @@ def test_disagg_mode(model_id): ) prefill_session = QAICInferenceSession(prefill_qpc_path) - logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) prefill_session.set_buffers({"logits": logits_out_placeholder}) inputs.pop("past_key_values") @@ -134,38 +102,6 @@ def test_disagg_mode(model_id): st = time.time() qpc_out = prefill_session.run(inputs) print(f"time for prefill_run={time.time() - st} sec\n") - import ipdb - - ipdb.set_trace() - decode_session = QAICInferenceSession(decode_qpc_path) - decode_session.set_buffers({"logits": logits_out_placeholder}) - decode_session.skip_buffers( - [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")] - ) - - decode_inputs = { - "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), - "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, - } - - all_outputs.append(decode_inputs["input_ids"][0][0]) - for i in range(config.num_hidden_layers): - decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] - decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] - - st = time.time() - decode_out = decode_session.run(decode_inputs) - print(f"time for first run of decode with KV as 
input = {time.time() - st} sec\n") - - st = time.time() - for i in range(generation_len - 2): - loop_decode_inputs = { - "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), - "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, - } - all_outputs.append(loop_decode_inputs["input_ids"][0][0]) - decode_out = decode_session.run(loop_decode_inputs) - - print(f"time for decode generation = {(time.time() - st) / (generation_len - 2)}") - print(all_outputs) - print(tokenizer.decode(all_outputs)) + del prefill_session + # Check QAIC output isclose with QEFF pytorch output + assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 5e-2 From 40ab876d977e43ea7835b240dc0c2e9294e7ab09 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 6 Nov 2025 10:46:23 +0000 Subject: [PATCH 17/37] made example not ugly Signed-off-by: Onkar Chougule --- examples/gpt_oss_disagg_mode.py | 22 ++-------------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/examples/gpt_oss_disagg_mode.py b/examples/gpt_oss_disagg_mode.py index 9f196e002..ee03f573a 100644 --- a/examples/gpt_oss_disagg_mode.py +++ b/examples/gpt_oss_disagg_mode.py @@ -9,18 +9,13 @@ import numpy as np import torch -from transformers import AutoModelForCausalLM, AutoTokenizer, HybridCache +from transformers import AutoTokenizer from QEfficient import QEFFAutoModelForCausalLM from QEfficient.generation.cloud_infer import QAICInferenceSession model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 -# prompt = """ -# Billions of years ago, in the vast emptiness of the early universe, tiny fluctuations in the density of matter began to grow under the influence of gravity. Clouds of gas—mostly hydrogen and helium—started to collapse, forming the first stars. These stars grouped together, bound by gravity, creating the earliest galaxies. -# Over time, these galaxies merged, collided, and evolved, shaping their spiral arms, elliptical forms, or irregular structures. Within their swirling depths, stars were born and died, enriching the galactic gas with heavier elements. These elements became the building blocks for planets, moons, and eventually life. -# Life is a very interesting phenomenon that occured in this universe -# """ -# prompt = "Once upon a time" + prompt = """ Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. 
@@ -44,15 +39,6 @@ max_gen_len = CTX_LEN - position_ids.max() generation_len = max_gen_len -# model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) -# config = model.config -# inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) -# inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) -# inputs.pop("token_type_ids", None) -# inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()} -# cache = HybridCache(config=config, batch_size=1, max_cache_len=8192) -# out = model(**tokenizer(prompt, return_tensors="pt"), past_key_values=cache) - qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) config = qeff_model.model.config @@ -70,7 +56,6 @@ past_key_values.append(pkv) inputs["past_key_values"] = past_key_values -# qeff_out = qeff_model.model(**inputs) decode_qpc_path = qeff_model.compile( prefill_seq_len=1, @@ -141,9 +126,6 @@ "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), "position_ids": pos_id, } - # for i in range(config.num_hidden_layers): - # loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] - # loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] all_outputs.append(loop_decode_inputs["input_ids"][0][0]) decode_out = decode_session.run(loop_decode_inputs) pos_id += 1 From b9defbef84f2f59ced7c0b143558cb20f1a8c856 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 6 Nov 2025 15:21:26 +0000 Subject: [PATCH 18/37] fixed tests Signed-off-by: Onkar Chougule --- QEfficient/base/modeling_qeff.py | 9 +++++---- .../transformers/models/gpt_oss/modeling_gpt_oss.py | 4 ++-- QEfficient/transformers/models/modeling_auto.py | 7 +++++++ tests/transformers/test_causal_lm.py | 2 ++ 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index b7078b7c9..33565c49d 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -373,11 +373,12 @@ def _compile( if prefill_only and self.prefill_onnx_path is None: kwargs.update({"prefill_only": prefill_only, "prefill_seq_len": specializations[0].get("seq_len")}) self.export(**kwargs) - onnx_path = Path(onnx_path or self.prefill_onnx_path) - - if onnx_path is None and self.onnx_path is None: + onnx_path = Path(self.prefill_onnx_path) + elif onnx_path is None: self.export(**kwargs) - onnx_path = Path(onnx_path or self.onnx_path) + onnx_path = Path(self.onnx_path) + else: + onnx_path = Path(onnx_path) compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 8545a859f..6fca620aa 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -488,7 +488,7 @@ def forward( key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, seq_len=getattr(self.config, "max_position_embeddings", 32 * 1024)) + cos, sin = self.rotary_emb(value_states, seq_len=getattr(self.config, "max_seq_len_cached", 32 * 1024)) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -566,7 +566,7 @@ def forward( key_states = 
self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - cos, sin = self.rotary_emb(value_states, seq_len=getattr(self.config, "max_position_embeddings", 32 * 1024)) + cos, sin = self.rotary_emb(value_states, seq_len=getattr(self.config, "max_seq_len_cached", 32 * 1024)) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 84191a3ea..182212049 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2362,6 +2362,7 @@ def __init__( model: nn.Module, continuous_batching: bool = False, qaic_config: Optional[dict] = None, + max_seq_len_cached: Optional[int] = None, **kwargs, ): """ @@ -2409,6 +2410,9 @@ def __init__( # Set use_cache=True to get KV values as output during ONNX export model.config.use_cache = True + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config) + + setattr(model.config, "max_seq_len_cached", max_seq_len_cached) super().__init__(model, qaic_config=qaic_config, **kwargs) self.num_layers = model.config.num_hidden_layers self.continuous_batching = continuous_batching @@ -2421,6 +2425,7 @@ def __init__( if qaic_config: self.ccl_enabled = qaic_config.get("ccl_enabled", False) self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None + self.hash_params["max_seq_len_cached"] = max_seq_len_cached # ---Sampling--- # Note: SamplerTransform should be applied after all other transforms @@ -2460,6 +2465,7 @@ def from_pretrained( pretrained_model_name_or_path, continuous_batching: bool = False, qaic_config: Optional[dict] = None, + max_seq_len_cached: Optional[int] = None, *args, **kwargs, ): @@ -2538,6 +2544,7 @@ def from_pretrained( continuous_batching=continuous_batching, qaic_config=qaic_config, pretrained_model_name_or_path=pretrained_model_name_or_path, + max_seq_len_cached=max_seq_len_cached, **kwargs, ) diff --git a/tests/transformers/test_causal_lm.py b/tests/transformers/test_causal_lm.py index 0810ac6ba..5e5ad4b5d 100644 --- a/tests/transformers/test_causal_lm.py +++ b/tests/transformers/test_causal_lm.py @@ -154,6 +154,7 @@ def test_causal_lm_hash_creation(config, cb, tmp_path): hash_params["peft_config"] = None hash_params["applied_transform_names"] = qeff_model._transform_names() hash_params["qeff_auto_class"] = qeff_model.__class__.__name__ + hash_params["max_seq_len_cached"] = None hash_params["qaic_config"] = None # Create parameters separately for hash creation @@ -204,6 +205,7 @@ def test_causal_lm_hash_creation(config, cb, tmp_path): export_params["output_names"] = output_names export_params["dynamic_axes"] = dynamic_axes hash_params["export_params"] = export_params + hash_params["prefill_only"] = False manual_hash = hash_dict_params(hash_params) assert manual_hash == qeff_model.export_hash From 446f4b64abe19f6aba70a7a6bbe69db855b5df21 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 6 Nov 2025 15:44:39 +0000 Subject: [PATCH 19/37] fixed tests Signed-off-by: Onkar Chougule --- tests/transformers/models/test_disagg_mode.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/transformers/models/test_disagg_mode.py b/tests/transformers/models/test_disagg_mode.py index fdbd374ff..0e303d389 100644 --- 
a/tests/transformers/models/test_disagg_mode.py +++ b/tests/transformers/models/test_disagg_mode.py @@ -16,8 +16,6 @@ from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.transformers.quantizers import replace_transformers_quantizers, undo_transformers_quantizers -replace_transformers_quantizers() - model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 prompt2 = """ @@ -45,6 +43,7 @@ def test_disagg_mode_prefill(model_id, prompt): num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + replace_transformers_quantizers() model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) config = model.config inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) From 099fd6102496399a69d5a1412f27a22a3feb90ed Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Fri, 7 Nov 2025 09:31:55 +0000 Subject: [PATCH 20/37] added new test and fixed failing tests Signed-off-by: Onkar Chougule --- QEfficient/peft/auto.py | 4 +- QEfficient/peft/lora/auto.py | 4 +- .../models/gpt_oss/modeling_gpt_oss.py | 178 +++++++++++++++++- .../transformers/models/modeling_auto.py | 44 +++-- QEfficient/utils/hash_utils.py | 1 - scripts/Jenkinsfile | 2 +- tests/peft/test_peft_model.py | 6 +- tests/transformers/test_causal_lm.py | 77 +++++--- 8 files changed, 255 insertions(+), 61 deletions(-) diff --git a/QEfficient/peft/auto.py b/QEfficient/peft/auto.py index e69aebb2b..5bf2d096c 100644 --- a/QEfficient/peft/auto.py +++ b/QEfficient/peft/auto.py @@ -253,7 +253,7 @@ def from_pretrained(cls, pretrained_name_or_path: str, *args, **kwargs): obj = cls._from_pretrained(pretrained_name_or_path, *args, **kwargs) return obj - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model with the active adapter to ONNX format. @@ -294,7 +294,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = export_kwargs={"do_constant_folding": False}, # To avoid merging adapter weights with base weights onnx_transform_kwargs={"adapter_name": self.model.active_adapter}, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + **kwargs, ) def compile( diff --git a/QEfficient/peft/lora/auto.py b/QEfficient/peft/lora/auto.py index 64fa3f61c..8ff8335f5 100644 --- a/QEfficient/peft/lora/auto.py +++ b/QEfficient/peft/lora/auto.py @@ -327,7 +327,7 @@ def _init_adapter_model(self): # load_weight to model self._load_adapter_weights_to_model() - def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = False) -> str: + def export(self, export_dir: Optional[str] = None, **kwargs) -> str: """ Export the model with all loaded adapters to ONNX format using ``torch.onnx.export``. 
@@ -387,7 +387,7 @@ def export(self, export_dir: Optional[str] = None, use_onnx_subfunctions: bool = output_names, dynamic_axes, export_dir=export_dir, - use_onnx_subfunctions=use_onnx_subfunctions, + **kwargs, ) def generate( diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 6fca620aa..7a8120bfc 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -4,6 +4,7 @@ # SPDX-License-Identifier: BSD-3-Clause # # ----------------------------------------------------------------------------- +import math import os from typing import Callable, Optional, Union @@ -97,6 +98,167 @@ def forward(self, hidden: torch.Tensor): # original shape [B, S, H] return expert_out.view(B, S, H), router_logits + def blocked_ffn_forward(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + target_blocks = int(os.environ.get("NUM_BLOCKS", 1)) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (T // target_blocks)) + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + block_count = 0 + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + + # Calculate block size (last block should be handled with remainder) + if block_idx == target_blocks - 1: + real_q_len = T - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + tgb = hidden[qi : qi + real_q_len, :] + # Gate and Up projections + # Gate and Up projections + gate = (tgb @ W_g) + b_g # [T, I] + up = (tgb @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=None, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + intermediate = (up + 1) * glu # [T, I] + + # Down projection + down_out_block = (intermediate @ W_d) + b_d # [T, H] + + outs.append(down_out_block) + + down_out = torch.cat(outs, dim=0) + + # Apply routing weights and accumulate + masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) + expert_out += masked_down + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + def blocked_ffn_forward_block_weights(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, 
self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + target_blocks = int(os.environ.get("NUM_BLOCKS", 1)) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (T // target_blocks)) + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + block_count = 0 + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + + # Calculate block size (last block should be handled with remainder) + if block_idx == target_blocks - 1: + real_q_len = T - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + tgb = hidden[qi : qi + real_q_len, :] + # Gate and Up projections + + wg_col_shape = W_g.shape[1] + wg_num_blocks = math.ceil(wg_col_shape / 128) + last_block_size = wg_col_shape % 128 if wg_col_shape % 128 != 0 else 128 + + intermediates = [] + for i in range(wg_num_blocks): + if i == wg_num_blocks - 1: + cur_gate = (tgb @ W_g[:, -last_block_size:]) + b_g[-last_block_size:] + cur_up = (tgb @ W_u[:, -last_block_size:]) + b_u[-last_block_size:] + else: + cur_gate = (tgb @ W_g[:, i * 128 : (i + 1) * 128]) + b_g[i * 128 : (i + 1) * 128] + cur_up = (tgb @ W_u[:, i * 128 : (i + 1) * 128]) + b_u[i * 128 : (i + 1) * 128] + + cur_gate = cur_gate.clamp(min=None, max=self.experts.limit) + cur_up = cur_up.clamp(min=-self.experts.limit, max=self.experts.limit) + cur_glu = cur_gate * torch.sigmoid(cur_gate * self.experts.alpha) + cur_intermediate = (cur_up + 1) * cur_glu + intermediates.append(cur_intermediate) + + intermediate = torch.cat(intermediates, dim=-1) + + downs = [] + for i in range(wg_num_blocks): + if i == wg_num_blocks - 1: + downs.append((intermediate @ W_d[:, -last_block_size:]) + b_d[-last_block_size:]) + else: + downs.append((intermediate @ W_d[:, i * 128 : (i + 1) * 128]) + b_d[i * 128 : (i + 1) * 128]) + + down_out_block = torch.cat(downs, dim=1) + outs.append(down_out_block) + + down_out = torch.cat(outs, dim=0) + + # Apply routing weights and accumulate + masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) + expert_out += masked_down + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + class QEffGptOssMLP(GptOssMLP): # ------------------- Gather based, weights as activation approach --------------- @@ -146,7 +308,6 @@ def forward_weights_as_activation(self, hidden_states): # ------------------- Gather based, weights as activation approach, With Seperate Gate, up Projections --------------- def forward(self, hidden_states): - # print("Seperate Split, Up, Gate Projections") bs, seq_len, _ = hidden_states.shape hidden_states = hidden_states.view(bs * seq_len, self.experts.hidden_size) @@ -417,7 +578,6 @@ def 
eager_attention_forward_blocked( scaling: float, **kwargs, ): - softmax_count = 0 key_states = repeat_kv(key, module.num_key_value_groups) value_states = repeat_kv(value, module.num_key_value_groups) @@ -426,9 +586,6 @@ def eager_attention_forward_blocked( block_positions = [] for j in range(target_blocks): block_positions.append(j * (CL // target_blocks)) - - print(f"CL={CL}, target_blocks={target_blocks}") - block_count = 0 outs = [] for block_idx in range(target_blocks): @@ -458,7 +615,6 @@ def eager_attention_forward_blocked( outs.append(out_block) output = torch.cat(outs, dim=2) - print(f"Completed {block_count} blocks, {softmax_count} softmax operations") output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous() return output, output @@ -487,8 +643,9 @@ def forward( hidden_shape = (*input_shape, -1, self.head_dim) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, seq_len=getattr(self.config, "max_seq_len_cached", 32 * 1024)) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: @@ -565,8 +722,9 @@ def forward( query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) - - cos, sin = self.rotary_emb(value_states, seq_len=getattr(self.config, "max_seq_len_cached", 32 * 1024)) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 182212049..8f7fed4c0 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2587,30 +2587,38 @@ def export( str Path to the generated ONNX graph file. """ + bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE + seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + kv_cache_shape = get_padding_shape_from_config( + self.model.config, fbs if self.continuous_batching else bs, seq_len + ) if prefill_only: - block_size = os.environ.get("BLOCK_SIZE", None) - if block_size is None: - block_size = 128 - logger.warning( - "Setting BLOCK_SIZE=128 for prefill_only model, please set ENV variable `BLOCK_SIZE` to override" - ) - if prefill_seq_len is None or prefill_seq_len % block_size != 0: - raise ValueError( - f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. 
" - f"Received: prefill_seq_len={prefill_seq_len}" - ) + assert not self.continuous_batching, "prefill_only=True is not supported with continuous_batching=True" - os.environ["NUM_BLOCKS"] = str(prefill_seq_len // block_size) if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH: + block_size = os.environ.get("BLOCK_SIZE", None) + if block_size is None: + block_size = 128 + logger.warning( + "Setting BLOCK_SIZE=128 for prefill_only model, please set ENV variable `BLOCK_SIZE` to override" + ) + if prefill_seq_len is None or prefill_seq_len % block_size != 0: + raise ValueError( + f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. " + f"Received: prefill_seq_len={prefill_seq_len}" + ) + os.environ["NUM_BLOCKS"] = str(prefill_seq_len // block_size) + self.prefill(True) + self.hash_params["prefill_only"] = True + self.hash_params["num_blocks"] = os.environ["NUM_BLOCKS"] + seq_len = prefill_seq_len // block_size if (prefill_seq_len // block_size) > seq_len else seq_len else: self.prefill(False) - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE - seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN if not prefill_only else prefill_seq_len // block_size - fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS - kv_cache_shape = get_padding_shape_from_config( - self.model.config, fbs if self.continuous_batching else bs, seq_len - ) + self.hash_params.pop("prefill_only", None) + self.hash_params.pop("num_blocks", None) + example_inputs = { "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), diff --git a/QEfficient/utils/hash_utils.py b/QEfficient/utils/hash_utils.py index 2c93e3990..948b72e6a 100644 --- a/QEfficient/utils/hash_utils.py +++ b/QEfficient/utils/hash_utils.py @@ -68,6 +68,5 @@ def create_export_hash(**kwargs): export_hash_params.update(onnx_transform_kwargs) if export_hash_params.get("peft_config") is not None and not isinstance(export_hash_params["peft_config"], dict): export_hash_params["peft_config"] = export_hash_params["peft_config"].to_dict() - export_hash_params["prefill_only"] = kwargs.get("prefill_only") return hash_dict_params(export_hash_params), export_hash_params diff --git a/scripts/Jenkinsfile b/scripts/Jenkinsfile index 134770638..4a9f2c33e 100644 --- a/scripts/Jenkinsfile +++ b/scripts/Jenkinsfile @@ -41,7 +41,7 @@ pipeline { mkdir -p $PWD/Non_cli_qaic && export TOKENIZERS_PARALLELISM=false && export QEFF_HOME=$PWD/Non_cli_qaic && - pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm --junitxml=tests/tests_log1.xml && + pytest tests -m '(not cli) and (not on_qaic) and (not finetune)' --ignore tests/vllm -n 4 --junitxml=tests/tests_log1.xml && junitparser merge tests/tests_log1.xml tests/tests_log.xml && deactivate" ''' diff --git a/tests/peft/test_peft_model.py b/tests/peft/test_peft_model.py index cc94467db..c3bb2f140 100644 --- a/tests/peft/test_peft_model.py +++ b/tests/peft/test_peft_model.py @@ -178,9 +178,9 @@ def test_auto_peft_model_for_causal_lm_activate_invalid(base_config, adapter_con def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_config, batch_size, tmp_path): _, lora_model = create_peft_model(base_config, adapter_config) qeff_model = QEffAutoPeftModelForCausalLM(lora_model) - qeff_model.export(tmp_path) + onnx_path = qeff_model.export(tmp_path) start = perf_counter() - qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, 
ctx_len=128) + qeff_model.compile(onnx_path=onnx_path, batch_size=batch_size, prefill_seq_len=32, ctx_len=128) end = perf_counter() compile_time_0 = end - start @@ -197,7 +197,7 @@ def test_auto_peft_model_for_causal_lm_compile_generate(base_config, adapter_con ) start = perf_counter() - qeff_model.compile(batch_size=batch_size, prefill_seq_len=32, ctx_len=128) + qeff_model.compile(onnx_path=onnx_path, batch_size=batch_size, prefill_seq_len=32, ctx_len=128) end = perf_counter() compile_time_1 = end - start assert compile_time_1 < 0.01 * compile_time_0 diff --git a/tests/transformers/test_causal_lm.py b/tests/transformers/test_causal_lm.py index 5e5ad4b5d..925af8b3a 100644 --- a/tests/transformers/test_causal_lm.py +++ b/tests/transformers/test_causal_lm.py @@ -17,7 +17,7 @@ from QEfficient.utils import constants, get_padding_shape_from_config from QEfficient.utils.hash_utils import hash_dict_params -configs = [ +test_configs = [ # name, max_position_embeddings, num_hidden_layers, num_attention_heads, hidden_size, intermediate_size, vocab_size, additional_params ("gpt2", 256, 2, 4, 128, 512, 127, {}), ("codegen", 256, 2, 4, 128, 512, 127, {"rotary_dim": 16}), @@ -36,30 +36,43 @@ ("gpt_oss", 256, 3, 4, 128, 512, 127, {"num_key_value_heads": 2}), ] -configs = [ - AutoConfig.for_model( - model_name, - max_position_embeddings=max_position_embeddings, - num_hidden_layers=num_hidden_layers, - num_attention_heads=num_attention_heads, - hidden_size=hidden_size, - intermediate_size=intermediate_size, - vocab_size=vocab_size, - **additional_params, - ) - for ( - model_name, - max_position_embeddings, - num_hidden_layers, - num_attention_heads, - hidden_size, - intermediate_size, - vocab_size, - additional_params, - ) in configs +test_prefill_only_specialized_models_configs = [ + ("gpt_oss", 256, 2, 2, 32, 32, 127, {"num_key_value_heads": 2}), ] + + +def get_auto_config_from_test_config(configs): + auto_configs = [ + AutoConfig.for_model( + model_name, + max_position_embeddings=max_position_embeddings, + num_hidden_layers=num_hidden_layers, + num_attention_heads=num_attention_heads, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + vocab_size=vocab_size, + **additional_params, + ) + for ( + model_name, + max_position_embeddings, + num_hidden_layers, + num_attention_heads, + hidden_size, + intermediate_size, + vocab_size, + additional_params, + ) in configs + ] + return auto_configs + + +configs = get_auto_config_from_test_config(test_configs) config_ids = [x.model_type for x in configs] +prefill_only_configs = get_auto_config_from_test_config(test_prefill_only_specialized_models_configs) +prefill_only_config_ids = [x.model_type for x in prefill_only_configs] + model_kwargs = {"attn_implementation": "eager"} @@ -158,7 +171,6 @@ def test_causal_lm_hash_creation(config, cb, tmp_path): hash_params["qaic_config"] = None # Create parameters separately for hash creation - bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS @@ -205,12 +217,29 @@ def test_causal_lm_hash_creation(config, cb, tmp_path): export_params["output_names"] = output_names export_params["dynamic_axes"] = dynamic_axes hash_params["export_params"] = export_params - hash_params["prefill_only"] = False manual_hash = hash_dict_params(hash_params) assert manual_hash == qeff_model.export_hash +@pytest.mark.parametrize("cb", [False, True], ids=["nocb", "cb"]) +@pytest.mark.parametrize("config", prefill_only_configs, 
ids=prefill_only_config_ids) +def test_prefill_only_specialized_models(config, cb, tmp_path): + model = AutoModelForCausalLM.from_config(config, **model_kwargs) + qeff_model = QEFFAutoModelForCausalLM(model, cb) + if cb: + with pytest.raises(AssertionError): + qeff_model.export(tmp_path, prefill_only=True, offload_pt_weights=False) + else: + with pytest.raises(ValueError): + qeff_model.export(tmp_path, prefill_only=True, offload_pt_weights=False) + qeff_model.export(tmp_path, prefill_only=True, prefill_seq_len=256, offload_pt_weights=False) + first_export_hash = qeff_model.export_hash + qeff_model.export(tmp_path, prefill_only=False, offload_pt_weights=False) + second_export_hash = qeff_model.export_hash + assert first_export_hash != second_export_hash + + @pytest.fixture def tmp_cache(tmp_path, monkeypatch): monkeypatch.setattr("QEfficient.utils._utils.QEFF_HOME", tmp_path) From 4d4639e6feff7a0cf4e4981fbd31da4cfd0fb4d0 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Mon, 10 Nov 2025 07:30:56 +0000 Subject: [PATCH 21/37] fixed tests Signed-off-by: Onkar Chougule --- tests/peft/lora/test_lora_model.py | 4 ++-- tests/transformers/models/test_disagg_mode.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/tests/peft/lora/test_lora_model.py b/tests/peft/lora/test_lora_model.py index 00a4216b7..46b33c60b 100644 --- a/tests/peft/lora/test_lora_model.py +++ b/tests/peft/lora/test_lora_model.py @@ -222,7 +222,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate( # export start = perf_counter() - qeff_model.export(export_dir=tmp_path) + onnx_path = qeff_model.export(export_dir=tmp_path) end = perf_counter() export_time_0 = end - start model_path = tmp_path.with_name(tmp_path.name + "-" + qeff_model.export_hash) @@ -237,7 +237,7 @@ def test_auto_lora_model_for_causal_lm_noncb_export_compile_generate( assert export_time_1 < export_time_0 # test compile - qeff_model.compile(prefill_seq_len=32, ctx_len=64) + qeff_model.compile(onnx_path=onnx_path, prefill_seq_len=32, ctx_len=64) assert Path(qeff_model.qpc_path).is_dir() assert os.path.isfile(os.path.join(os.path.dirname(qeff_model.qpc_path), "qconfig.json")) diff --git a/tests/transformers/models/test_disagg_mode.py b/tests/transformers/models/test_disagg_mode.py index 0e303d389..67ee48944 100644 --- a/tests/transformers/models/test_disagg_mode.py +++ b/tests/transformers/models/test_disagg_mode.py @@ -56,9 +56,7 @@ def test_disagg_mode_prefill(model_id, prompt): undo_transformers_quantizers() - qeff_model = QEFFAutoModelForCausalLM.from_pretrained( - model_id, num_hidden_layers=2, max_position_embeddings=64 * 1024 - ) + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) qeff_model.prefill(True) config = qeff_model.model.config inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) From f32df629a8432cb1011bc4e7f4547daf660c9b1b Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Mon, 10 Nov 2025 11:08:13 +0000 Subject: [PATCH 22/37] fixed kv cache shape Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/modeling_auto.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 8f7fed4c0..4dcb4c5d6 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2590,9 +2590,6 @@ def export( bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int 
= constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS - kv_cache_shape = get_padding_shape_from_config( - self.model.config, fbs if self.continuous_batching else bs, seq_len - ) if prefill_only: assert not self.continuous_batching, "prefill_only=True is not supported with continuous_batching=True" @@ -2619,6 +2616,9 @@ def export( self.hash_params.pop("prefill_only", None) self.hash_params.pop("num_blocks", None) + kv_cache_shape = get_padding_shape_from_config( + self.model.config, fbs if self.continuous_batching else bs, seq_len + ) example_inputs = { "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), From 0b29ba4589d34a242be3bcb7b3bef83a4aec9fcd Mon Sep 17 00:00:00 2001 From: Onkar Chougule <168134249+ochougul@users.noreply.github.com> Date: Tue, 11 Nov 2025 15:09:28 +0530 Subject: [PATCH 23/37] fixed self.onnx_path issue in modeling_qeff Signed-off-by: Onkar Chougule <168134249+ochougul@users.noreply.github.com> --- QEfficient/base/modeling_qeff.py | 29 +++++++++++++++++------------ 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 33565c49d..b0fe65d03 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -327,6 +327,22 @@ def _export( self.onnx_path = onnx_path return onnx_path + def get_onnx_path(self, prefill_only: Optional[bool] = False, + specializations: Optional[List[Dict[str, int]]] = None, + offload_pt_weights: Optional[bool] = True, + use_onnx_subfunctions: Optional[bool] = False): + kwargs = {"offload_pt_weights": offload_pt_weights, + "use_onnx_subfunctions": use_onnx_subfunctions} + if prefill_only: + if self.prefill_onnx_path is None: + kwargs.update({"prefill_only": prefill_only, "prefill_seq_len": specializations[0].get("seq_len")}) + self.export(**kwargs) + return self.prefill_onnx_path + else: + if self.onnx_path is None: + self.export(**kwargs) + return self.onnx_path + @dump_qconfig def _compile( self, @@ -368,18 +384,7 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored. 
""" - kwargs = {"offload_pt_weights": offload_pt_weights, - "use_onnx_subfunctions": use_onnx_subfunctions} - if prefill_only and self.prefill_onnx_path is None: - kwargs.update({"prefill_only": prefill_only, "prefill_seq_len": specializations[0].get("seq_len")}) - self.export(**kwargs) - onnx_path = Path(self.prefill_onnx_path) - elif onnx_path is None: - self.export(**kwargs) - onnx_path = Path(self.onnx_path) - else: - onnx_path = Path(onnx_path) - + onnx_path = Path(onnx_path if onnx_path else self.get_onnx_path(prefill_only, specializations, offload_pt_weights, use_onnx_subfunctions)) compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" if not onnx_path.is_file(): From 053acaad18de121f8b0c93846a07f2670ac8fcf6 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Thu, 13 Nov 2025 10:29:09 +0000 Subject: [PATCH 24/37] added ffn blocking and num blocks env variables Signed-off-by: Onkar Chougule --- .../models/gpt_oss/modeling_gpt_oss.py | 4 +- .../transformers/models/modeling_auto.py | 53 ++++++++++++------- 2 files changed, 37 insertions(+), 20 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 7a8120bfc..228a0c677 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -47,6 +47,8 @@ def __qeff_init__(self): class QEffPrefillOnlyGptOssMLP(GptOssMLP): def forward(self, hidden: torch.Tensor): + if os.environ.get("NUM_FFN_BLOCKS", None) is not None: + return self.blocked_ffn_forward(hidden) B, S, H = hidden.shape T = B * S hidden = hidden.view(T, H) @@ -118,7 +120,7 @@ def blocked_ffn_forward(self, hidden: torch.Tensor): # ────────────────── allocate the output tensor ───── expert_out = hidden.new_zeros((T, H)) # accumulation buffer - target_blocks = int(os.environ.get("NUM_BLOCKS", 1)) + target_blocks = int(os.environ.get("NUM_FFN_BLOCKS", 1)) block_positions = [] for j in range(target_blocks): block_positions.append(j * (T // target_blocks)) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 4dcb4c5d6..9db2ab596 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2560,6 +2560,35 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ + def get_seq_len_and_handle_specialized_prefill_model(self, prefill_seq_len: Optional[int] = None) -> int: + num_q_blocks = os.environ.get("NUM_Q_BLOCKS", None) + if num_q_blocks is None: + block_size = 128 + if prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128: + raise ValueError( + f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. 
" + f"Or set `NUM_BLOCKS` ENV variable" + f"Received: prefill_seq_len={prefill_seq_len}" + ) + + num_q_blocks = prefill_seq_len // block_size + logger.warning( + f"Setting NUM_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_BLOCKS` to override" + ) + os.environ["NUM_Q_BLOCKS"] = num_q_blocks + + num_ffn_blocks = os.environ.get("NUM_FFN_BLOCKS", None) + min_seq_len = int(max(num_q_blocks, num_ffn_blocks)) if num_ffn_blocks else num_q_blocks + + self.prefill(True) + self.hash_params["prefill_only"] = True + self.hash_params["num_blocks"] = os.environ["NUM_BLOCKS"] + return ( + min_seq_len + if min_seq_len > constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + else constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + ) + def export( self, export_dir: Optional[str] = None, @@ -2592,25 +2621,11 @@ def export( fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS if prefill_only: assert not self.continuous_batching, "prefill_only=True is not supported with continuous_batching=True" - - if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH: - block_size = os.environ.get("BLOCK_SIZE", None) - if block_size is None: - block_size = 128 - logger.warning( - "Setting BLOCK_SIZE=128 for prefill_only model, please set ENV variable `BLOCK_SIZE` to override" - ) - if prefill_seq_len is None or prefill_seq_len % block_size != 0: - raise ValueError( - f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. " - f"Received: prefill_seq_len={prefill_seq_len}" - ) - os.environ["NUM_BLOCKS"] = str(prefill_seq_len // block_size) - - self.prefill(True) - self.hash_params["prefill_only"] = True - self.hash_params["num_blocks"] = os.environ["NUM_BLOCKS"] - seq_len = prefill_seq_len // block_size if (prefill_seq_len // block_size) > seq_len else seq_len + seq_len = ( + self.get_seq_len_and_handle_specialized_prefill_model(prefill_seq_len) + if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH + else seq_len + ) else: self.prefill(False) self.hash_params.pop("prefill_only", None) From 8447c183e38ab366aa30b20bfdc8f53f22b56abd Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Mon, 17 Nov 2025 07:50:02 +0000 Subject: [PATCH 25/37] include num_ffn_blocks in hash Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/modeling_auto.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 9db2ab596..b49a60ad5 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2575,14 +2575,22 @@ def get_seq_len_and_handle_specialized_prefill_model(self, prefill_seq_len: Opti logger.warning( f"Setting NUM_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_BLOCKS` to override" ) - os.environ["NUM_Q_BLOCKS"] = num_q_blocks + os.environ["NUM_Q_BLOCKS"] = str(num_q_blocks) + num_q_blocks = int(num_q_blocks) num_ffn_blocks = os.environ.get("NUM_FFN_BLOCKS", None) - min_seq_len = int(max(num_q_blocks, num_ffn_blocks)) if num_ffn_blocks else num_q_blocks + num_ffn_blocks = int(num_ffn_blocks) if num_ffn_blocks else num_ffn_blocks + min_seq_len = max(num_q_blocks, num_ffn_blocks) if num_ffn_blocks else num_q_blocks + if (num_ffn_blocks and min_seq_len % num_ffn_blocks != 0) or min_seq_len % num_q_blocks != 0: + raise ValueError( + f"Got NUM_FFN_BLOCKS={num_ffn_blocks} and 
NUM_Q_BLOCKS={num_q_blocks}, tried to set seq_len={min_seq_len} for export but," + "seq_len is not divisible by either num_ffn_blocks or num_q_blocks, try chaning the values." + ) self.prefill(True) self.hash_params["prefill_only"] = True - self.hash_params["num_blocks"] = os.environ["NUM_BLOCKS"] + self.hash_params["num_blocks"] = num_q_blocks + self.hash_params["num_ffn_blocks"] = num_ffn_blocks return ( min_seq_len if min_seq_len > constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN From 3982c9d364afe793812531b65f507ac8186835a9 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 18 Nov 2025 10:23:05 +0000 Subject: [PATCH 26/37] fixed dynamic range in case of subfunc issue and nonmatching ctx, prefill seq_len for prefill_only gpt_oss model Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py | 2 +- QEfficient/transformers/models/modeling_auto.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 228a0c677..8cf23911d 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -663,7 +663,7 @@ def forward( } if self.sliding_window is not None: sliding_window_len = past_key_value.sliding_window_len - short_read_idx = torch.arange(sliding_window_len) + short_read_idx = torch.arange(past_key_value.key_cache[self.layer_idx].shape[2]) read_idx = short_read_idx + torch.where( position_ids.max() > sliding_window_len - 1, position_ids.max() - sliding_window_len + 1, 0 ) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index b49a60ad5..cd3fc9f6a 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3136,6 +3136,7 @@ def compile( kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, num_speculative_tokens=num_speculative_tokens, + prefill_only=prefill_only, ) if decode_spec: specializations.append(decode_spec) From 4c38de3baad1e481d427e0b9a1a4569b8f9c6c68 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 18 Nov 2025 16:47:10 +0000 Subject: [PATCH 27/37] added swa optimization for reducing MACCs using less KV Signed-off-by: Onkar Chougule --- .../models/gpt_oss/modeling_gpt_oss.py | 75 ++++++++++++++++++- .../transformers/models/modeling_auto.py | 4 +- 2 files changed, 73 insertions(+), 6 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 8cf23911d..efdccf8ec 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -83,7 +83,7 @@ def forward(self, hidden: torch.Tensor): up = (hidden @ W_u) + b_u # [T, I] # Apply GptOss activation with clamping - gate = gate.clamp(min=None, max=self.experts.limit) + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) up = up.clamp(min=-self.experts.limit, max=self.experts.limit) # GLU activation @@ -584,11 +584,12 @@ def eager_attention_forward_blocked( value_states = repeat_kv(value, module.num_key_value_groups) BS, NH, CL, DH = query.shape - target_blocks = int(os.environ.get("NUM_BLOCKS", 1)) + target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1)) block_positions = [] for j in range(target_blocks): block_positions.append(j * (CL // target_blocks)) block_count 
= 0 + outs = [] for block_idx in range(target_blocks): block_count += 1 @@ -621,6 +622,69 @@ def eager_attention_forward_blocked( return output, output +def opt_eager_attention_forward_blocked( + module: nn.Module, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor], + scaling: float, + **kwargs, +): + key_states = repeat_kv(key, module.num_key_value_groups) + value_states = repeat_kv(value, module.num_key_value_groups) + + BS, NH, CL, DH = query.shape + target_blocks = int(os.environ.get("NUM_Q_BLOCKS", 1)) + block_positions = [] + for j in range(target_blocks): + block_positions.append(j * (CL // target_blocks)) + block_count = 0 + outs = [] + for block_idx in range(target_blocks): + block_count += 1 + qi = block_positions[block_idx] + # Calculate block size (last block should be handled with remainder) + + if block_idx == target_blocks - 1: + real_q_len = CL - qi + else: + real_q_len = block_positions[block_idx + 1] - qi + + if block_idx == 0: + kv_start_idx = 0 + else: + kv_start_idx = qi - 128 + + q_block = query[:, :, qi : qi + real_q_len, :] + if kwargs.get("sliding_window"): + k_block = key_states[:, :, kv_start_idx : qi + real_q_len, :] + v_block = value_states[:, :, kv_start_idx : qi + real_q_len, :] + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, kv_start_idx : qi + real_q_len] + else: + k_block = key_states + v_block = value_states + attn_mask_block = attention_mask[:, :, qi : qi + real_q_len, :] + + scores = torch.matmul(q_block, k_block.transpose(2, 3)) * scaling + curr_attn_weights = torch.where( + attn_mask_block, torch.tensor(MIN_MASKED_ATTENTION_VALUE, dtype=torch.float32), scores + ) + sinks = module.sinks.reshape(1, -1, 1, 1).expand( + curr_attn_weights.shape[0], -1, curr_attn_weights.shape[-2], -1 + ) + combined_logits = torch.cat([curr_attn_weights, sinks], dim=-1) + combined_logits = combined_logits - combined_logits.max(dim=-1, keepdim=True).values + curr_attn_weights = nn.functional.softmax(combined_logits, dim=-1, dtype=torch.float32) + curr_attn_weights = curr_attn_weights[..., :-1] + out_block = torch.matmul(curr_attn_weights, v_block) + outs.append(out_block) + output = torch.cat(outs, dim=2) + + output = output.view(BS, NH, CL, DH).transpose(1, 2).contiguous() + return output, output + + class QEffPrefillOnlyGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -667,7 +731,7 @@ def forward( read_idx = short_read_idx + torch.where( position_ids.max() > sliding_window_len - 1, position_ids.max() - sliding_window_len + 1, 0 ) - # This is a trick to export with NUM_BLOCKS position_ids.max(), 0, read_idx) k_cache = key_states[:, :, read_idx, :] v_cache = value_states[:, :, read_idx, :] @@ -680,7 +744,10 @@ def forward( else: attention_mask = attention_mask - attention_interface: Callable = eager_attention_forward_blocked + if os.environ.get("ENABLE_OPT_SWA", "0") == "1": + attention_interface: Callable = opt_eager_attention_forward_blocked + else: + attention_interface: Callable = eager_attention_forward_blocked attn_output, attn_weights = attention_interface( self, query_states, diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index cd3fc9f6a..5558563ca 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2567,13 +2567,13 @@ def get_seq_len_and_handle_specialized_prefill_model(self, prefill_seq_len: Opti if 
prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128: raise ValueError( f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. " - f"Or set `NUM_BLOCKS` ENV variable" + f"Or set `NUM_Q_BLOCKS` ENV variable" f"Received: prefill_seq_len={prefill_seq_len}" ) num_q_blocks = prefill_seq_len // block_size logger.warning( - f"Setting NUM_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_BLOCKS` to override" + f"Setting NUM_Q_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_Q_BLOCKS` to override" ) os.environ["NUM_Q_BLOCKS"] = str(num_q_blocks) num_q_blocks = int(num_q_blocks) From 5e5d70882b45ea4054215032163da16eed581787 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Mon, 24 Nov 2025 19:24:00 +0000 Subject: [PATCH 28/37] added opt swa to hash Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/modeling_auto.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 5558563ca..2ba6d82bb 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2589,8 +2589,9 @@ def get_seq_len_and_handle_specialized_prefill_model(self, prefill_seq_len: Opti self.prefill(True) self.hash_params["prefill_only"] = True - self.hash_params["num_blocks"] = num_q_blocks - self.hash_params["num_ffn_blocks"] = num_ffn_blocks + self.hash_params["NUM_Q_BLOCKS"] = num_q_blocks + self.hash_params["NUM_FFN_BLOCKS"] = num_ffn_blocks + self.hash_params["ENABLE_OPT_SWA"] = os.environ.get("ENABLE_OPT_SWA", "0") return ( min_seq_len if min_seq_len > constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN From 88ae0bea278cda1ab8d31611d428d52072cb172b Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Mon, 24 Nov 2025 21:12:33 +0000 Subject: [PATCH 29/37] lint and format Signed-off-by: Onkar Chougule --- QEfficient/base/modeling_qeff.py | 20 +++++++++++++------- QEfficient/utils/_utils.py | 1 - 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index b0fe65d03..d4bbb0232 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -327,12 +327,14 @@ def _export( self.onnx_path = onnx_path return onnx_path - def get_onnx_path(self, prefill_only: Optional[bool] = False, - specializations: Optional[List[Dict[str, int]]] = None, - offload_pt_weights: Optional[bool] = True, - use_onnx_subfunctions: Optional[bool] = False): - kwargs = {"offload_pt_weights": offload_pt_weights, - "use_onnx_subfunctions": use_onnx_subfunctions} + def get_onnx_path( + self, + prefill_only: Optional[bool] = False, + specializations: Optional[List[Dict[str, int]]] = None, + offload_pt_weights: Optional[bool] = True, + use_onnx_subfunctions: Optional[bool] = False, + ): + kwargs = {"offload_pt_weights": offload_pt_weights, "use_onnx_subfunctions": use_onnx_subfunctions} if prefill_only: if self.prefill_onnx_path is None: kwargs.update({"prefill_only": prefill_only, "prefill_seq_len": specializations[0].get("seq_len")}) @@ -384,7 +386,11 @@ def _compile( For QNN Compilation path, when enable_qnn is set to True, any parameter passed in compiler_options will be ignored. 
""" - onnx_path = Path(onnx_path if onnx_path else self.get_onnx_path(prefill_only, specializations, offload_pt_weights, use_onnx_subfunctions)) + onnx_path = Path( + onnx_path + if onnx_path + else self.get_onnx_path(prefill_only, specializations, offload_pt_weights, use_onnx_subfunctions) + ) compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" if not onnx_path.is_file(): diff --git a/QEfficient/utils/_utils.py b/QEfficient/utils/_utils.py index ea22d932f..4db1f6405 100644 --- a/QEfficient/utils/_utils.py +++ b/QEfficient/utils/_utils.py @@ -559,7 +559,6 @@ def wrapper(self, *args, **kwargs): # Get arguments as a dictionary all_args = bound_args.arguments - export_hash, filtered_hash_params = create_export_hash( model_params=self.hash_params, output_names=all_args.get("output_names"), From 5d014c27b67bdd73248f06a9a3522bb124ae07ac Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Wed, 26 Nov 2025 11:47:57 +0000 Subject: [PATCH 30/37] enabled chunking Signed-off-by: Onkar Chougule --- QEfficient/base/modeling_qeff.py | 24 ++++- QEfficient/customop/ctx_scatter_gather.py | 1 + QEfficient/transformers/cache_utils.py | 60 +++++++++++++ .../models/gpt_oss/modeling_gpt_oss.py | 90 ++++++++++++++++++- .../transformers/models/modeling_auto.py | 40 ++++++--- .../transformers/models/pytorch_transforms.py | 13 ++- tests/transformers/models/test_disagg_mode.py | 87 ++++++++++++++++++ 7 files changed, 297 insertions(+), 18 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index d4bbb0232..4a6484843 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -330,6 +330,7 @@ def _export( def get_onnx_path( self, prefill_only: Optional[bool] = False, + enable_chunking: Optional[bool] = False, specializations: Optional[List[Dict[str, int]]] = None, offload_pt_weights: Optional[bool] = True, use_onnx_subfunctions: Optional[bool] = False, @@ -337,7 +338,13 @@ def get_onnx_path( kwargs = {"offload_pt_weights": offload_pt_weights, "use_onnx_subfunctions": use_onnx_subfunctions} if prefill_only: if self.prefill_onnx_path is None: - kwargs.update({"prefill_only": prefill_only, "prefill_seq_len": specializations[0].get("seq_len")}) + kwargs.update( + { + "prefill_only": prefill_only, + "prefill_seq_len": specializations[0].get("seq_len"), + "enable_chunking": enable_chunking, + } + ) self.export(**kwargs) return self.prefill_onnx_path else: @@ -361,6 +368,7 @@ def _compile( use_onnx_subfunctions: bool = False, prefill_only: Optional[str] = None, offload_pt_weights: Optional[bool] = True, + enable_chunking: Optional[bool] = False, **compiler_options, ) -> str: """ @@ -389,7 +397,9 @@ def _compile( onnx_path = Path( onnx_path if onnx_path - else self.get_onnx_path(prefill_only, specializations, offload_pt_weights, use_onnx_subfunctions) + else self.get_onnx_path( + prefill_only, enable_chunking, specializations, offload_pt_weights, use_onnx_subfunctions + ) ) compile_dir = Path(compile_dir or onnx_path.parent) qpc_path = compile_dir / "qpc" @@ -492,6 +502,16 @@ def _compile( command.append(f"-aic-binary-dir={qpc_path}") logger.info(f"Running compiler: {' '.join(command)}") + if use_onnx_subfunctions: + + class FeatureNotAvailableError(Exception): + pass + + exec_command = f'QAIC_COMPILER_OPTS_UNSUPPORTED="-loader-inline-all=0" {" ".join(command)}' + raise FeatureNotAvailableError( + f"ONNX graph is exported with subfunctions, assert version of apps SDK should be used for compiling this model. 
\ + Run following command manually with assert compiler:\n{exec_command}" + ) try: subprocess.run(command, capture_output=True, check=True) except subprocess.CalledProcessError as e: diff --git a/QEfficient/customop/ctx_scatter_gather.py b/QEfficient/customop/ctx_scatter_gather.py index c7dc8639a..7b15effe7 100644 --- a/QEfficient/customop/ctx_scatter_gather.py +++ b/QEfficient/customop/ctx_scatter_gather.py @@ -136,6 +136,7 @@ class CtxGatherFunc(torch.autograd.Function): def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + ctx_indices = torch.where(ctx_indices == torch.iinfo(torch.int32).max, 0, ctx_indices) return data[batch_indices, head_indices, ctx_indices] @staticmethod diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index 90cbdb2dd..fd6a81691 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -778,3 +778,63 @@ def update( v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out + + def full_cache_update_chunked( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + position_ids = cache_kwargs.get("position_ids") + + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) + + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + + # Gather + ctx_len = cache_kwargs.get("CCL", k_out.shape[2]) + ctx_indices = torch.arange(ctx_len)[None, None, ...] + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + + return k_out, v_out + + def sliding_window_update_chunked( + self, + key_states: torch.Tensor, + value_states: torch.Tensor, + layer_idx: int, + cache_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + position_ids = cache_kwargs.get("position_ids") + + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) + + k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] + sliding_window_len = cache_kwargs.get("sliding_window") + # Gather + ctx_len = position_ids.shape[1] + sliding_window_len + ctx_indices = torch.arange(ctx_len)[None, None, ...] 
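+        # Gather window for chunked prefill: cover the current chunk (position_ids.shape[1] tokens)
+        # plus the preceding sliding_window tokens; once the chunk start moves past the window,
+        # shift the indices forward by (chunk_start - sliding_window) so the window trails the chunk.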
+ # positive_pos_ids = torch.where(position_ids<0, 0, position_ids) + first_pos_idx = position_ids[0][0] + add_idx = torch.where(first_pos_idx >= sliding_window_len, first_pos_idx - sliding_window_len, 0) + ctx_indices += add_idx + gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) + invalid_mask = ctx_indices > gather_limit + + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() + ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) + + return k_out, v_out diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index efdccf8ec..50c31943a 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -685,6 +685,88 @@ def opt_eager_attention_forward_blocked( return output, output +class QEffPrefillOnlyChunkedGptOssAttention(GptOssAttention): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __qeff_init__(self): + self.rotary_emb = QEffGptOssRotaryEmbedding(config=self.config) + + def forward( + self, + hidden_states: torch.Tensor, + position_embeddings: tuple[torch.Tensor, torch.Tensor], + attention_mask: Optional[torch.Tensor], + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + batch_index: Optional[torch.LongTensor] = None, + cache_position: Optional[torch.LongTensor] = None, + sliding_mask=None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple[torch.Tensor, torch.Tensor]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2) + hidden_shape = (*input_shape, -1, self.head_dim) + key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2) + if not (max_seq_len_cached := getattr(self.config, "max_seq_len_cached")): + max_seq_len_cached = 32 * 1024 + cos, sin = self.rotary_emb(value_states, seq_len=max_seq_len_cached) + query_states, key_states = qeff_apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + # sin and cos are specific to RoPE models; cache_position needed for the static cache + cache_kwargs = { + "sin": sin, + "cos": cos, + "batch_index": batch_index, + "position_ids": position_ids, + "config": self.config, + "is_sliding": self.sliding_window is not None, + "sliding_window": self.sliding_window, + } + if self.sliding_window is not None: + key_states, value_states = past_key_value.sliding_window_update_chunked( + key_states, value_states, self.layer_idx, cache_kwargs + ) + else: + key_states, value_states = past_key_value.full_cache_update_chunked( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + if self.sliding_window is not None: + attention_mask = sliding_mask + # positive_pos_ids = torch.where(position_ids<0, 0, position_ids) + ctx_len = position_ids.shape[1] + self.sliding_window + ctx_indices = torch.arange(ctx_len) + first_pos_idx = position_ids[0][0] + add_idx = torch.where(first_pos_idx >= self.sliding_window, first_pos_idx - self.sliding_window, 0) + # start_idx = 
torch.where(first_pos_idx>=self.sliding_window, first_pos_idx-self.sliding_window, 0) + # end_idx = torch.where(first_pos_idx >= self.sliding_window, first_pos_idx+position_ids.shape[1], position_ids.shape[1]+self.sliding_window) + ctx_indices += add_idx + attention_mask = attention_mask[:, :, :, ctx_indices] + else: + attention_mask = attention_mask + + attention_interface: Callable = eager_attention_forward + attn_output, attn_weights = attention_interface( + self, + query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=self.scaling, + sliding_window=self.sliding_window, + s_aux=self.sinks, # diff with Llama + **kwargs, + ) + + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights, past_key_value + + class QEffPrefillOnlyGptOssAttention(GptOssAttention): """Multi-headed attention from 'Attention Is All You Need' paper""" @@ -935,9 +1017,8 @@ def forward( sliding_mask = _create_causal_mask( position_ids=position_ids, target_length=past_key_values.max_cache_len, - sliding_window=past_key_values.sliding_window_len, + sliding_window=self.config.sliding_window, ) - hidden_states = inputs_embeds # decoder layers @@ -1163,12 +1244,13 @@ def forward( def get_pkv_dynamic_axes( self, + chunked_prefill: Optional[bool] = False, ): pkv_dynamic_axes = [] for layer_type in self.config.layer_types: - if layer_type == "sliding_attention": + if layer_type == "sliding_attention" and not chunked_prefill: pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"}) - elif layer_type == "full_attention": + else: pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"}) return pkv_dynamic_axes diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 2ba6d82bb..34eb9597a 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -48,6 +48,7 @@ KVCacheExternalModuleMapperTransform, KVCacheTransform, PoolingTransform, + PrefillOnlyChunkedTransform, PrefillOnlyTransform, RevertPrefillOnlyTransform, SamplerTransform, @@ -2349,9 +2350,12 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def prefill(self, enable: Optional[bool] = True): + def prefill(self, enable: Optional[bool] = True, enable_chunking: Optional[bool] = False): if enable: - self.model, tf = PrefillOnlyTransform.apply(self.model) + if enable_chunking: + self.model, tf = PrefillOnlyChunkedTransform.apply(self.model) + else: + self.model, tf = PrefillOnlyTransform.apply(self.model) self.prefill_enabled = True else: self.model, tf = RevertPrefillOnlyTransform.apply(self.model) @@ -2560,7 +2564,15 @@ def get_model_config(self) -> dict: """ return self.model.config.__dict__ - def get_seq_len_and_handle_specialized_prefill_model(self, prefill_seq_len: Optional[int] = None) -> int: + def get_seq_len_and_handle_specialized_prefill_model( + self, prefill_seq_len: Optional[int] = None, enable_chunking=False + ) -> int: + self.prefill(enable=True, enable_chunking=enable_chunking) + self.hash_params["prefill_only"] = True + if enable_chunking: + self.hash_params["chunking"] = True + return constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN + num_q_blocks = os.environ.get("NUM_Q_BLOCKS", None) if num_q_blocks is None: block_size = 128 @@ -2587,8 +2599,6 @@ def get_seq_len_and_handle_specialized_prefill_model(self, 
prefill_seq_len: Opti "seq_len is not divisible by either num_ffn_blocks or num_q_blocks, try chaning the values." ) - self.prefill(True) - self.hash_params["prefill_only"] = True self.hash_params["NUM_Q_BLOCKS"] = num_q_blocks self.hash_params["NUM_FFN_BLOCKS"] = num_ffn_blocks self.hash_params["ENABLE_OPT_SWA"] = os.environ.get("ENABLE_OPT_SWA", "0") @@ -2603,7 +2613,6 @@ def export( export_dir: Optional[str] = None, prefill_only: Optional[bool] = False, prefill_seq_len: Optional[int] = None, - offload_pt_weights: Optional[bool] = True, **kwargs, ) -> str: """ @@ -2628,21 +2637,26 @@ def export( bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS + kv_cache_shape = get_padding_shape_from_config( + self.model.config, fbs if self.continuous_batching else bs, seq_len + ) if prefill_only: assert not self.continuous_batching, "prefill_only=True is not supported with continuous_batching=True" seq_len = ( - self.get_seq_len_and_handle_specialized_prefill_model(prefill_seq_len) + self.get_seq_len_and_handle_specialized_prefill_model( + prefill_seq_len=prefill_seq_len, enable_chunking=kwargs.get("enable_chunking", False) + ) if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH else seq_len ) + kv_cache_shape[2] = ( + seq_len + self.model.config.sliding_window if kwargs.get("enable_chunking", False) else seq_len + ) else: self.prefill(False) self.hash_params.pop("prefill_only", None) self.hash_params.pop("num_blocks", None) - kv_cache_shape = get_padding_shape_from_config( - self.model.config, fbs if self.continuous_batching else bs, seq_len - ) example_inputs = { "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1), @@ -2691,7 +2705,9 @@ def export( else: # HACK: create common function for this including above if condition code pkv_dynamic_axes = ( - self.model.get_pkv_dynamic_axes() if hasattr(self.model, "get_pkv_dynamic_axes") else pkv_dynamic_axes + self.model.get_pkv_dynamic_axes(chunked_prefill=(prefill_only and kwargs.get("enable_chunking", False))) + if hasattr(self.model, "get_pkv_dynamic_axes") + else pkv_dynamic_axes ) pkv_dynamic_axes = ( [pkv_dynamic_axes] * self.model.config.num_hidden_layers @@ -2945,6 +2961,7 @@ def compile( prefill_only: Optional[bool] = None, use_onnx_subfunctions: bool = False, offload_pt_weights: Optional[bool] = True, + enable_chunking: Optional[bool] = False, **compiler_options, ) -> str: """ @@ -3166,6 +3183,7 @@ def compile( use_onnx_subfunctions=use_onnx_subfunctions, prefill_only=prefill_only, offload_pt_weights=offload_pt_weights, + enable_chunking=enable_chunking, **compiler_options, ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 3bdd21d54..761a5098a 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -261,6 +261,7 @@ QEffGptOssForCausalLM, QEffGptOssMLP, QEffGptOssModel, + QEffPrefillOnlyChunkedGptOssAttention, QEffPrefillOnlyGptOssAttention, QEffPrefillOnlyGptOssMLP, QEffPrefillOnlyGptOssModel, @@ -645,8 +646,18 @@ class PrefillOnlyTransform(ModuleMappingTransform): } +class PrefillOnlyChunkedTransform(ModuleMappingTransform): + _module_mapping = { + QEffGptOssModel: QEffPrefillOnlyGptOssModel, + QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, + QEffGptOssMLP: 
QEffPrefillOnlyGptOssMLP, + } + + class RevertPrefillOnlyTransform(ModuleMappingTransform): - _module_mapping = {v: k for k, v in PrefillOnlyTransform._module_mapping.items()} + _module_mapping = {v: k for k, v in PrefillOnlyTransform._module_mapping.items()}.update( + {v: k for k, v in PrefillOnlyChunkedTransform._module_mapping.items()} + ) class SpDTransform: diff --git a/tests/transformers/models/test_disagg_mode.py b/tests/transformers/models/test_disagg_mode.py index 67ee48944..8851365e2 100644 --- a/tests/transformers/models/test_disagg_mode.py +++ b/tests/transformers/models/test_disagg_mode.py @@ -102,3 +102,90 @@ def test_disagg_mode_prefill(model_id, prompt): del prefill_session # Check QAIC output isclose with QEFF pytorch output assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 5e-2 + + +@pytest.mark.on_qaic +@pytest.mark.parametrize("model_id", [model_id]) +@pytest.mark.parametrize("prompt", prompts) +def test_disagg_mode_prefill_chunked(model_id, prompt): + # Run prefill + tokenizer = AutoTokenizer.from_pretrained(model_id) + PREFILL_SEQ_LEN = 128 + CTX_LEN = 128 * 3 + inputs = tokenizer(prompt, return_tensors="np", padding=True) + padded_len = inputs["input_ids"].shape[1] + num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float + padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + + replace_transformers_quantizers() + model = AutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + config = model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v).to(model.device) for k, v in inputs.items()} + cache = HybridCache(config=config, batch_size=1, max_cache_len=CTX_LEN) + ins = tokenizer(prompt, return_tensors="pt") + out = model(**ins, past_key_values=cache) + + undo_transformers_quantizers() + + qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) + qeff_model.prefill(True, enable_chunking=True) + config = qeff_model.model.config + inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) + inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) + inputs.pop("token_type_ids", None) + inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} + past_key_values = [] + for i in range(config.num_hidden_layers): + cache_len = CTX_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) + inputs["past_key_values"] = past_key_values + + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + + qeff_out = qeff_model.model(**chunk_inputs) + inputs["past_key_values"] = qeff_out["past_key_values"] + + # Check our pytorch implementation + assert (qeff_out.logits - out.logits[:, -1, :]).abs().max() < 1e-4 + + prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=False, + mxint8_kv_cache=False, + num_devices=1, + mos=1, + aic_enable_depth_first=True, 
+ num_speculative_tokens=None, + prefill_only=True, + enable_chunking=True, + ) + prefill_session = QAICInferenceSession(prefill_qpc_path) + prefill_session.skip_buffers( + [x for x in prefill_session.input_names + prefill_session.output_names if x.startswith("past_")] + ) + logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) + prefill_session.set_buffers({"logits": logits_out_placeholder}) + inputs.pop("past_key_values") + inputs = {k: v.detach().numpy() for k, v in inputs.items()} + st = time.time() + for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + qpc_out = prefill_session.run(chunk_inputs) + print(f"time for prefill_run={time.time() - st} sec\n") + del prefill_session + # Check QAIC output isclose with QEFF pytorch output + assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 8e-2 From f1b1785f470dc70f8821ccf160f6b0917831e8a2 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Mon, 1 Dec 2025 10:35:29 +0000 Subject: [PATCH 31/37] added ChunkedPrefillMLP block; fixed passing prefill_only flag and enable_chunking flag to get_specialization for gpt-oss Signed-off-by: Onkar Chougule --- .../models/gpt_oss/modeling_gpt_oss.py | 65 +++++++++++++++++-- .../transformers/models/modeling_auto.py | 2 + .../transformers/models/pytorch_transforms.py | 3 +- 3 files changed, 63 insertions(+), 7 deletions(-) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 50c31943a..853fe12cd 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -45,6 +45,61 @@ def __qeff_init__(self): self.up_proj_bias = nn.Parameter(torch.empty(self.num_experts, self.expert_dim)) +class QEffPrefillOnlyChunkedGptOssMLP(GptOssMLP): + def forward(self, hidden: torch.Tensor): + B, S, H = hidden.shape + T = B * S + hidden = hidden.view(T, H) + + # Router computation + router_logits = F.linear(hidden, self.router.weight, self.router.bias) + + # Top-k selection + top_w, top_i = torch.topk(router_logits, self.router.top_k, dim=-1) # both [T, K] + top_w = torch.nn.functional.softmax(top_w, dim=1, dtype=top_w.dtype) + + masked_logits = torch.zeros_like(router_logits) + masked_logits.scatter_(1, top_i, top_w) + + # Routing weights for each expert [T, E] + routing_weights = masked_logits + + # ────────────────── allocate the output tensor ───── + expert_out = hidden.new_zeros((T, H)) # accumulation buffer + + # ───────────────────────── Expert computation loop ───────────────────────────── + for e in range(self.experts.num_experts): + routing_weight = routing_weights[:, e].unsqueeze(-1) # [T, 1] + + W_g, W_u = self.experts.gate_proj[e], self.experts.up_proj[e] # [H, I], [H, I] + b_g, b_u = self.experts.gate_proj_bias[e], self.experts.up_proj_bias[e] # [I], [I] + W_d = self.experts.down_proj[e] # [I, H] + b_d = self.experts.down_proj_bias[e] # [H] + + # Gate and Up projections + gate = (hidden @ W_g) + b_g # [T, I] + up = (hidden @ W_u) + b_u # [T, I] + + # Apply GptOss activation with clamping + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) + up = up.clamp(min=-self.experts.limit, max=self.experts.limit) + + # GLU activation + glu = gate * torch.sigmoid(gate * self.experts.alpha) + 
intermediate = (up + 1) * glu # [T, I] + + # Down projection + down_out = (intermediate @ W_d) + b_d # [T, H] + + # Apply routing weights and accumulate + masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) + expert_out += masked_down + + # original shape [B, S, H] + return expert_out.view(B, S, H), router_logits + + + class QEffPrefillOnlyGptOssMLP(GptOssMLP): def forward(self, hidden: torch.Tensor): if os.environ.get("NUM_FFN_BLOCKS", None) is not None: @@ -152,7 +207,7 @@ def blocked_ffn_forward(self, hidden: torch.Tensor): up = (tgb @ W_u) + b_u # [T, I] # Apply GptOss activation with clamping - gate = gate.clamp(min=None, max=self.experts.limit) + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) up = up.clamp(min=-self.experts.limit, max=self.experts.limit) # GLU activation @@ -234,7 +289,7 @@ def blocked_ffn_forward_block_weights(self, hidden: torch.Tensor): cur_gate = (tgb @ W_g[:, i * 128 : (i + 1) * 128]) + b_g[i * 128 : (i + 1) * 128] cur_up = (tgb @ W_u[:, i * 128 : (i + 1) * 128]) + b_u[i * 128 : (i + 1) * 128] - cur_gate = cur_gate.clamp(min=None, max=self.experts.limit) + cur_gate = cur_gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) cur_up = cur_up.clamp(min=-self.experts.limit, max=self.experts.limit) cur_glu = cur_gate * torch.sigmoid(cur_gate * self.experts.alpha) cur_intermediate = (cur_up + 1) * cur_glu @@ -339,7 +394,7 @@ def forward(self, hidden_states): up = torch.bmm(expert_in, up_proj) + up_proj_bias.unsqueeze(1) # Apply activation with clamping - gate = gate.clamp(min=None, max=self.experts.limit) + gate = gate.clamp(min=torch.finfo(torch.float16).min, max=self.experts.limit) up = up.clamp(min=-self.experts.limit, max=self.experts.limit) # GLU activation @@ -1262,13 +1317,11 @@ def get_specializations( **kwargs, ): batch_size = batch_size if batch_size else 1 - prefill_seq_len = prefill_seq_len if prefill_seq_len else constants.PROMPT_LEN - if kwargs.get("prefill_only") and ctx_len != prefill_seq_len: + if kwargs.get("prefill_only") and not kwargs.get("enable_chunking") and ctx_len != prefill_seq_len: ctx_len = prefill_seq_len logger.warning( f"overriding ctx_len={prefill_seq_len}, currently we don't support ctx_len different than prefill_seq_len for prefill_only model" ) - ctx_len = ctx_len if ctx_len else constants.CTX_LEN specializations = [ { diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 34eb9597a..4809e1b96 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3127,6 +3127,8 @@ def compile( batch_size=batch_size, kv_cache_batch_size=kv_cache_batch_size, full_batch_size=full_batch_size, + prefill_only=prefill_only, + enable_chunking=enable_chunking, ) ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 761a5098a..a033f2364 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -262,6 +262,7 @@ QEffGptOssMLP, QEffGptOssModel, QEffPrefillOnlyChunkedGptOssAttention, + QEffPrefillOnlyChunkedGptOssMLP, QEffPrefillOnlyGptOssAttention, QEffPrefillOnlyGptOssMLP, QEffPrefillOnlyGptOssModel, @@ -650,7 +651,7 @@ class PrefillOnlyChunkedTransform(ModuleMappingTransform): _module_mapping = { QEffGptOssModel: QEffPrefillOnlyGptOssModel, QEffGptOssAttention: 
QEffPrefillOnlyChunkedGptOssAttention, - QEffGptOssMLP: QEffPrefillOnlyGptOssMLP, + QEffGptOssMLP: QEffPrefillOnlyChunkedGptOssMLP, } From 723f4adb8a6ed217963383e0b0a50448c66f8f62 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 2 Dec 2025 11:10:09 +0000 Subject: [PATCH 32/37] added disagg mode example for chunking mode Signed-off-by: Onkar Chougule --- .../models/gpt_oss/modeling_gpt_oss.py | 1 - .../transformers/models/pytorch_transforms.py | 7 +- examples/gpt_oss_disagg_mode.py | 1 + examples/gpt_oss_disagg_mode_with_chunking.py | 156 ++++++++++++++++++ 4 files changed, 161 insertions(+), 4 deletions(-) create mode 100644 examples/gpt_oss_disagg_mode_with_chunking.py diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 853fe12cd..fb4c0acf5 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -99,7 +99,6 @@ def forward(self, hidden: torch.Tensor): return expert_out.view(B, S, H), router_logits - class QEffPrefillOnlyGptOssMLP(GptOssMLP): def forward(self, hidden: torch.Tensor): if os.environ.get("NUM_FFN_BLOCKS", None) is not None: diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index a033f2364..16b0401cc 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -656,9 +656,10 @@ class PrefillOnlyChunkedTransform(ModuleMappingTransform): class RevertPrefillOnlyTransform(ModuleMappingTransform): - _module_mapping = {v: k for k, v in PrefillOnlyTransform._module_mapping.items()}.update( - {v: k for k, v in PrefillOnlyChunkedTransform._module_mapping.items()} - ) + _module_mapping = { + **{v: k for k, v in PrefillOnlyTransform._module_mapping.items()}, + **{v: k for k, v in PrefillOnlyChunkedTransform._module_mapping.items()}, + } class SpDTransform: diff --git a/examples/gpt_oss_disagg_mode.py b/examples/gpt_oss_disagg_mode.py index ee03f573a..fd0d5b045 100644 --- a/examples/gpt_oss_disagg_mode.py +++ b/examples/gpt_oss_disagg_mode.py @@ -80,6 +80,7 @@ aic_enable_depth_first=True, num_speculative_tokens=None, prefill_only=True, + use_onnx_subfunctions=True, ) prefill_session = QAICInferenceSession(prefill_qpc_path) diff --git a/examples/gpt_oss_disagg_mode_with_chunking.py b/examples/gpt_oss_disagg_mode_with_chunking.py new file mode 100644 index 000000000..36be5244d --- /dev/null +++ b/examples/gpt_oss_disagg_mode_with_chunking.py @@ -0,0 +1,156 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import time + +import numpy as np +import torch +from transformers import AutoTokenizer + +from QEfficient import QEFFAutoModelForCausalLM +from QEfficient.generation.cloud_infer import QAICInferenceSession + +model_id = "openai/gpt-oss-120b" # weights are not required to convert to fp32 + +prompt = """ +Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. 
The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. + +As Alex flipped through the pages, he discovered a map that led to a hidden treasure. Excited by the prospect of a real-life treasure hunt, Alex decided to embark on a thrilling journey. He packed his backpack with snacks, a flashlight, and a compass, and set off into the unknown. + +The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. +""" +all_outputs = [] +# Run prefill +tokenizer = AutoTokenizer.from_pretrained(model_id) +PREFILL_SEQ_LEN = 128 +CTX_LEN = 2 * 128 +inputs = tokenizer(prompt, return_tensors="np", padding=True) +position_ids = inputs["attention_mask"].sum(1, keepdims=True) +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len + +# Initialize variables specific to request +# Calculate the max generation length. +max_gen_len = CTX_LEN - position_ids.max() +generation_len = max_gen_len + + +# qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) +qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) + + +decode_qpc_path = qeff_model.compile( + prefill_seq_len=1, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step +) + +config = qeff_model.model.config +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +past_key_values = [] +for i in range(config.num_hidden_layers): + cache_len = config.sliding_window if i % 2 == 0 else PREFILL_SEQ_LEN + pad_shape = (1, 8, cache_len, 64) + past_key = torch.zeros((pad_shape), dtype=torch.float32) + past_value = torch.zeros((pad_shape), dtype=torch.float32) + pkv = (past_key, past_value) + past_key_values.append(pkv) +inputs["past_key_values"] = past_key_values + +prefill_qpc_path = qeff_model.compile( + prefill_seq_len=PREFILL_SEQ_LEN, + ctx_len=CTX_LEN, + num_cores=16, + mxfp6_matmul=True, + mxint8_kv_cache=True, + num_devices=1, + mos=1, + aic_enable_depth_first=True, + num_speculative_tokens=None, + prefill_only=True, + enable_chunking=True, + use_onnx_subfunctions=True, + offload_pt_weights=False, +) +print("loading qpc") +st = time.time() +prefill_session = QAICInferenceSession(prefill_qpc_path, device_ids=[i for i in range(32, 48)]) +print(f"time for loading session = {time.time() - st}") +print("done") +prefill_session.skip_buffers( + [x for x in prefill_session.input_names + prefill_session.output_names if x.startswith("past_")] +) +logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) +prefill_session.set_buffers({"logits": logits_out_placeholder}) +inputs.pop("past_key_values") +inputs = {k: v.detach().numpy() for k, v in inputs.items()} +st = time.time() + +for i in range(num_chunks): + chunk_inputs = inputs.copy() + chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + chunk_inputs["position_ids"] = 
inputs["position_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] + ins = time.time() + qpc_out = prefill_session.run(chunk_inputs) + print(f"time for this run={time.time() - ins}") +print(f"time for prefill_run={time.time() - st} sec\n") + +decode_session = QAICInferenceSession(decode_qpc_path) +decode_session.set_buffers({"logits": logits_out_placeholder}) + +decode_inputs = { + "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), + "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, +} +print("pos_id for decodee", decode_inputs["position_ids"]) + +all_outputs.append(decode_inputs["input_ids"][0][0]) +for i in range(config.num_hidden_layers): + if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window: + last_valid_pos_idx = decode_inputs["position_ids"][0][0] + first_valid_pos_idx = last_valid_pos_idx - config.sliding_window + k = qpc_out[f"past_key.{i}_RetainedState"][:, :, first_valid_pos_idx:last_valid_pos_idx, :] + v = qpc_out[f"past_value.{i}_RetainedState"][:, :, first_valid_pos_idx:last_valid_pos_idx, :] + mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window + decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2) + decode_inputs[f"past_value.{i}"] = np.concatenate((v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2) + else: + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + +st = time.time() +decode_out = decode_session.run(decode_inputs) +print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") +decode_session.skip_buffers( + [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")] +) +pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +st = time.time() +for i in range(generation_len - 2): + loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + all_outputs.append(loop_decode_inputs["input_ids"][0][0]) + decode_out = decode_session.run(loop_decode_inputs) + pos_id += 1 + + +print(f"time for decode generation = {(time.time() - st) / (generation_len - 2)}") +print(all_outputs) +print(tokenizer.decode(all_outputs)) From 1bc8ee9407303b64bbade9d9429ee2255e31936c Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 2 Dec 2025 14:00:01 +0000 Subject: [PATCH 33/37] fixed the kwargs passing to build_decode_specialization Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/modeling_auto.py | 1 + 1 file changed, 1 insertion(+) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 4809e1b96..da83105e0 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -2890,6 +2890,7 @@ def build_decode_specialization( kv_cache_batch_size: Optional[int] = None, full_batch_size: Optional[int] = None, num_speculative_tokens: Optional[int] = None, + **kwargs, ): """ Builds a dictionary representing a compilation specialization for the decode phase. 
From abccb26b6bede848aae8d3b450f7048ca6439525 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Mon, 8 Dec 2025 07:54:44 +0000 Subject: [PATCH 34/37] pushed latest changes with chunking enabled for prefill along with retaining full KV for decode-only model Signed-off-by: Onkar Chougule --- QEfficient/base/modeling_qeff.py | 15 +- .../models/gpt_oss/modeling_gpt_oss.py | 4 +- .../transformers/models/modeling_auto.py | 36 ++++- .../transformers/models/pytorch_transforms.py | 10 ++ examples/gpt_oss_disagg_mode_with_chunking.py | 131 ++++++++---------- tests/transformers/models/test_disagg_mode.py | 3 +- 6 files changed, 112 insertions(+), 87 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index 4a6484843..af9cebf79 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -334,8 +334,13 @@ def get_onnx_path( specializations: Optional[List[Dict[str, int]]] = None, offload_pt_weights: Optional[bool] = True, use_onnx_subfunctions: Optional[bool] = False, + retain_full_kv: Optional[bool] = False, ): - kwargs = {"offload_pt_weights": offload_pt_weights, "use_onnx_subfunctions": use_onnx_subfunctions} + kwargs = { + "offload_pt_weights": offload_pt_weights, + "use_onnx_subfunctions": use_onnx_subfunctions, + "retain_full_kv": retain_full_kv, + } if prefill_only: if self.prefill_onnx_path is None: kwargs.update( @@ -369,6 +374,7 @@ def _compile( prefill_only: Optional[str] = None, offload_pt_weights: Optional[bool] = True, enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = None, **compiler_options, ) -> str: """ @@ -398,7 +404,12 @@ def _compile( onnx_path if onnx_path else self.get_onnx_path( - prefill_only, enable_chunking, specializations, offload_pt_weights, use_onnx_subfunctions + prefill_only, + enable_chunking, + specializations, + offload_pt_weights, + use_onnx_subfunctions, + retain_full_kv, ) ) compile_dir = Path(compile_dir or onnx_path.parent) diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index fb4c0acf5..309eab7af 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -1298,11 +1298,11 @@ def forward( def get_pkv_dynamic_axes( self, - chunked_prefill: Optional[bool] = False, + retain_full_kv: Optional[bool] = False, ): pkv_dynamic_axes = [] for layer_type in self.config.layer_types: - if layer_type == "sliding_attention" and not chunked_prefill: + if layer_type == "sliding_attention" and not retain_full_kv: pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"}) else: pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"}) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index da83105e0..4c0a424b7 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -50,6 +50,7 @@ PoolingTransform, PrefillOnlyChunkedTransform, PrefillOnlyTransform, + RevertPrefillKeepAttentionTransform, RevertPrefillOnlyTransform, SamplerTransform, SpDTransform, @@ -2350,7 +2351,12 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel): _onnx_transforms = [FP16ClipTransform, SplitTensorsTransform] - def prefill(self, enable: Optional[bool] = True, enable_chunking: Optional[bool] = False): + def prefill( + self, + enable: Optional[bool] = True, + enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = 
False, + ): if enable: if enable_chunking: self.model, tf = PrefillOnlyChunkedTransform.apply(self.model) @@ -2358,7 +2364,10 @@ def prefill(self, enable: Optional[bool] = True, enable_chunking: Optional[bool] self.model, tf = PrefillOnlyTransform.apply(self.model) self.prefill_enabled = True else: - self.model, tf = RevertPrefillOnlyTransform.apply(self.model) + if retain_full_kv: + self.model, tf = RevertPrefillKeepAttentionTransform.apply(self.model) + else: + self.model, tf = RevertPrefillOnlyTransform.apply(self.model) self.prefill_enabled = False def __init__( @@ -2533,7 +2542,6 @@ def from_pretrained( qaic_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path # This is support models that should be classified to in a different auto class but transformers load them via this class - if model.__class__.__name__ in MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP: return MISCLASSIFIED_CAUSAL_LM_TO_QEFF_AUTO_CLASS_MAP[model.__class__.__name__]( model, @@ -2567,7 +2575,6 @@ def get_model_config(self) -> dict: def get_seq_len_and_handle_specialized_prefill_model( self, prefill_seq_len: Optional[int] = None, enable_chunking=False ) -> int: - self.prefill(enable=True, enable_chunking=enable_chunking) self.hash_params["prefill_only"] = True if enable_chunking: self.hash_params["chunking"] = True @@ -2642,6 +2649,8 @@ def export( ) if prefill_only: assert not self.continuous_batching, "prefill_only=True is not supported with continuous_batching=True" + self.prefill(enable=True, enable_chunking=kwargs.get("enable_chunking", False)) + self.hash_params.pop("retain_full_kv", None) seq_len = ( self.get_seq_len_and_handle_specialized_prefill_model( prefill_seq_len=prefill_seq_len, enable_chunking=kwargs.get("enable_chunking", False) @@ -2653,9 +2662,15 @@ def export( seq_len + self.model.config.sliding_window if kwargs.get("enable_chunking", False) else seq_len ) else: - self.prefill(False) + self.prefill(False, retain_full_kv=kwargs.get("retain_full_kv", False)) self.hash_params.pop("prefill_only", None) - self.hash_params.pop("num_blocks", None) + self.hash_params.pop("NUM_Q_BLOCKS", None) + self.hash_params.pop("NUM_FFN_BLOCKS", None) + self.hash_params.pop("ENABLE_OPT_SWA", None) + self.hash_params.pop("chunking", None) + if kwargs.get("retain_full_kv", False): + kv_cache_shape[2] = seq_len + self.model.config.sliding_window + self.hash_params["retain_full_kv"] = True example_inputs = { "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64), @@ -2705,7 +2720,10 @@ def export( else: # HACK: create common function for this including above if condition code pkv_dynamic_axes = ( - self.model.get_pkv_dynamic_axes(chunked_prefill=(prefill_only and kwargs.get("enable_chunking", False))) + self.model.get_pkv_dynamic_axes( + retain_full_kv=kwargs.get("retain_full_kv", False) + or (prefill_only and kwargs.get("enable_chunking", False)) + ) if hasattr(self.model, "get_pkv_dynamic_axes") else pkv_dynamic_axes ) @@ -2963,6 +2981,7 @@ def compile( use_onnx_subfunctions: bool = False, offload_pt_weights: Optional[bool] = True, enable_chunking: Optional[bool] = False, + retain_full_kv: Optional[bool] = None, **compiler_options, ) -> str: """ @@ -3109,6 +3128,8 @@ def compile( if self.comp_ctx_lengths_prefill is not None: # Adding elements from self.comp_ctx_lengths_prefill to prefill_specialization for i in range(0, len(self.comp_ctx_lengths_prefill)): + if prefill_only or enable_chunking: + raise NotImplementedError("prefill_only or enable_chunking is not supported with CCL") 
specializations.append( self.build_prefill_specialization( prefill_seq_len=prefill_seq_len, @@ -3187,6 +3208,7 @@ def compile( prefill_only=prefill_only, offload_pt_weights=offload_pt_weights, enable_chunking=enable_chunking, + retain_full_kv=retain_full_kv, **compiler_options, ) diff --git a/QEfficient/transformers/models/pytorch_transforms.py b/QEfficient/transformers/models/pytorch_transforms.py index 16b0401cc..d9b96e42d 100644 --- a/QEfficient/transformers/models/pytorch_transforms.py +++ b/QEfficient/transformers/models/pytorch_transforms.py @@ -655,6 +655,16 @@ class PrefillOnlyChunkedTransform(ModuleMappingTransform): } +class RevertPrefillKeepAttentionTransform(ModuleMappingTransform): + _module_mapping = { + QEffGptOssModel: QEffPrefillOnlyGptOssModel, + QEffPrefillOnlyGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, + QEffGptOssAttention: QEffPrefillOnlyChunkedGptOssAttention, + QEffPrefillOnlyGptOssMLP: QEffGptOssMLP, + QEffPrefillOnlyChunkedGptOssMLP: QEffGptOssMLP, + } + + class RevertPrefillOnlyTransform(ModuleMappingTransform): _module_mapping = { **{v: k for k, v in PrefillOnlyTransform._module_mapping.items()}, diff --git a/examples/gpt_oss_disagg_mode_with_chunking.py b/examples/gpt_oss_disagg_mode_with_chunking.py index 36be5244d..363e2806c 100644 --- a/examples/gpt_oss_disagg_mode_with_chunking.py +++ b/examples/gpt_oss_disagg_mode_with_chunking.py @@ -9,12 +9,12 @@ import numpy as np import torch -from transformers import AutoTokenizer +from transformers import AutoConfig, AutoTokenizer from QEfficient import QEFFAutoModelForCausalLM from QEfficient.generation.cloud_infer import QAICInferenceSession -model_id = "openai/gpt-oss-120b" # weights are not required to convert to fp32 +model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 prompt = """ Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. @@ -23,27 +23,14 @@ The path to the treasure was not an easy one. Alex had to navigate through dense forests, cross rickety bridges, and solve riddles that guarded the treasure's location. """ -all_outputs = [] # Run prefill +config = AutoConfig.from_pretrained(model_id) tokenizer = AutoTokenizer.from_pretrained(model_id) PREFILL_SEQ_LEN = 128 -CTX_LEN = 2 * 128 -inputs = tokenizer(prompt, return_tensors="np", padding=True) -position_ids = inputs["attention_mask"].sum(1, keepdims=True) -padded_len = inputs["input_ids"].shape[1] -num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float -padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len +CTX_LEN = 128 * 3 -# Initialize variables specific to request -# Calculate the max generation length. 
-max_gen_len = CTX_LEN - position_ids.max() -generation_len = max_gen_len - - -# qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, num_hidden_layers=2) qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id) - decode_qpc_path = qeff_model.compile( prefill_seq_len=1, ctx_len=CTX_LEN, @@ -55,23 +42,12 @@ aic_enable_depth_first=True, num_speculative_tokens=None, offload_pt_weights=False, # Need the weights in memory for prefill-model export/compilation in the next step + retain_full_kv=True, ) -config = qeff_model.model.config -inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) -inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) -inputs.pop("token_type_ids", None) -inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} -past_key_values = [] -for i in range(config.num_hidden_layers): - cache_len = config.sliding_window if i % 2 == 0 else PREFILL_SEQ_LEN - pad_shape = (1, 8, cache_len, 64) - past_key = torch.zeros((pad_shape), dtype=torch.float32) - past_value = torch.zeros((pad_shape), dtype=torch.float32) - pkv = (past_key, past_value) - past_key_values.append(pkv) -inputs["past_key_values"] = past_key_values +# Following command errors out by default, the user is supposed to run the printed command and provide the generated qpc path as prefill_qpc_path commenting out lines 55-68 +# prefill_qpc_path = "provide path here" prefill_qpc_path = qeff_model.compile( prefill_seq_len=PREFILL_SEQ_LEN, ctx_len=CTX_LEN, @@ -85,22 +61,27 @@ prefill_only=True, enable_chunking=True, use_onnx_subfunctions=True, - offload_pt_weights=False, -) -print("loading qpc") -st = time.time() -prefill_session = QAICInferenceSession(prefill_qpc_path, device_ids=[i for i in range(32, 48)]) -print(f"time for loading session = {time.time() - st}") -print("done") -prefill_session.skip_buffers( - [x for x in prefill_session.input_names + prefill_session.output_names if x.startswith("past_")] ) -logits_out_placeholder = np.zeros((1, 1, 201088), dtype=np.float32) -prefill_session.set_buffers({"logits": logits_out_placeholder}) -inputs.pop("past_key_values") + + +inputs = tokenizer(prompt, return_tensors="np", padding=True) +position_ids = inputs["attention_mask"].sum(1, keepdims=True) +generation_len = CTX_LEN - position_ids.max() +padded_len = inputs["input_ids"].shape[1] +num_chunks = -(padded_len // -PREFILL_SEQ_LEN) # ceil divide without float +padded_len = num_chunks * PREFILL_SEQ_LEN # Convert to a multiple of prompt_len +inputs = tokenizer(prompt, return_tensors="np", padding="max_length", max_length=padded_len) +inputs["position_ids"] = np.where(inputs.pop("attention_mask"), np.arange(padded_len), -1) +inputs.pop("token_type_ids", None) +inputs = {k: torch.from_numpy(v) for k, v in inputs.items()} +inputs.pop("past_key_values", None) inputs = {k: v.detach().numpy() for k, v in inputs.items()} -st = time.time() + +decode_session = QAICInferenceSession(decode_qpc_path) +prefill_session = QAICInferenceSession(prefill_qpc_path) + +all_outputs = [] for i in range(num_chunks): chunk_inputs = inputs.copy() chunk_inputs["input_ids"] = inputs["input_ids"][:, i * PREFILL_SEQ_LEN : (i + 1) * PREFILL_SEQ_LEN] @@ -108,49 +89,49 @@ ins = time.time() qpc_out = prefill_session.run(chunk_inputs) print(f"time for this run={time.time() - ins}") -print(f"time for prefill_run={time.time() - st} sec\n") - -decode_session = QAICInferenceSession(decode_qpc_path) -decode_session.set_buffers({"logits": logits_out_placeholder}) + for i in 
range(config.num_hidden_layers): + inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] +all_outputs.append(np.argmax(qpc_out["logits"])) decode_inputs = { "input_ids": np.argmax(qpc_out["logits"]).reshape(1, 1), "position_ids": np.max(inputs["position_ids"]).reshape(1, 1) + 1, } -print("pos_id for decodee", decode_inputs["position_ids"]) - -all_outputs.append(decode_inputs["input_ids"][0][0]) for i in range(config.num_hidden_layers): - if i % 2 == 0 and decode_inputs["position_ids"] >= config.sliding_window: - last_valid_pos_idx = decode_inputs["position_ids"][0][0] - first_valid_pos_idx = last_valid_pos_idx - config.sliding_window - k = qpc_out[f"past_key.{i}_RetainedState"][:, :, first_valid_pos_idx:last_valid_pos_idx, :] - v = qpc_out[f"past_value.{i}_RetainedState"][:, :, first_valid_pos_idx:last_valid_pos_idx, :] - mod_pos_id = config.sliding_window - decode_inputs["position_ids"][0][0] % config.sliding_window - decode_inputs[f"past_key.{i}"] = np.concatenate((k[:, :, mod_pos_id:, :], k[:, :, :mod_pos_id, :]), axis=-2) - decode_inputs[f"past_value.{i}"] = np.concatenate((v[:, :, mod_pos_id:, :], v[:, :, :mod_pos_id, :]), axis=-2) - else: - decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] - decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] + decode_inputs[f"past_key.{i}"] = qpc_out[f"past_key.{i}_RetainedState"] + decode_inputs[f"past_value.{i}"] = qpc_out[f"past_value.{i}_RetainedState"] st = time.time() decode_out = decode_session.run(decode_inputs) print(f"time for first run of decode with KV as input = {time.time() - st} sec\n") -decode_session.skip_buffers( - [x for x in decode_session.input_names + decode_session.output_names if x.startswith("past_")] -) +all_outputs.append(np.argmax(decode_out["logits"])) pos_id = np.max(decode_inputs["position_ids"]).reshape(1, 1) + 1 +loop_decode_inputs = { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, +} + +for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + st = time.time() for i in range(generation_len - 2): - loop_decode_inputs = { - "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), - "position_ids": pos_id, - } - all_outputs.append(loop_decode_inputs["input_ids"][0][0]) decode_out = decode_session.run(loop_decode_inputs) + all_outputs.append(np.argmax(decode_out["logits"])) pos_id += 1 - - -print(f"time for decode generation = {(time.time() - st) / (generation_len - 2)}") -print(all_outputs) -print(tokenizer.decode(all_outputs)) + for i in range(config.num_hidden_layers): + loop_decode_inputs[f"past_key.{i}"] = decode_out[f"past_key.{i}_RetainedState"] + loop_decode_inputs[f"past_value.{i}"] = decode_out[f"past_value.{i}_RetainedState"] + + loop_decode_inputs.update( + { + "input_ids": np.argmax(decode_out["logits"]).reshape(1, 1), + "position_ids": pos_id, + } + ) +ft = time.time() + +print(f"decode tok/sec={(generation_len - 2) / (ft - st)}") +print(f"input\n{prompt}\noutput\n{tokenizer.decode(all_outputs)}") diff --git a/tests/transformers/models/test_disagg_mode.py b/tests/transformers/models/test_disagg_mode.py index 8851365e2..6358940df 100644 --- a/tests/transformers/models/test_disagg_mode.py +++ b/tests/transformers/models/test_disagg_mode.py @@ -16,7 +16,7 @@ from 
QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.transformers.quantizers import replace_transformers_quantizers, undo_transformers_quantizers -model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32 +model_id = "openai/gpt-oss-120b" # weights are not required to convert to fp32 prompt2 = """ Once upon a time, in a small town, there lived a young boy named Alex. Alex was a curious and adventurous child, always eager to explore the world around him. One day, while playing in the park, Alex stumbled upon a mysterious old book hidden beneath a pile of leaves. The book was filled with stories of distant lands, magical creatures, and extraordinary adventures. @@ -104,6 +104,7 @@ def test_disagg_mode_prefill(model_id, prompt): assert (torch.from_numpy(qpc_out["logits"]) - qeff_out.logits).abs().max() < 5e-2 +@pytest.mark.skip(reason="no way of currently testing this without the assert sdk") @pytest.mark.on_qaic @pytest.mark.parametrize("model_id", [model_id]) @pytest.mark.parametrize("prompt", prompts) From 1721b9d998b257912000081945b39017eaf42de6 Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Mon, 8 Dec 2025 15:03:35 +0000 Subject: [PATCH 35/37] added support for prefix caching for gpt-oss Signed-off-by: Onkar Chougule --- QEfficient/base/modeling_qeff.py | 4 +- QEfficient/base/onnx_transforms.py | 4 +- QEfficient/customop/ctx_scatter_gather_cb.py | 1 + QEfficient/transformers/cache_utils.py | 56 ++++++++++++++----- .../models/gpt_oss/modeling_gpt_oss.py | 20 +++---- .../transformers/models/modeling_auto.py | 25 ++++----- 6 files changed, 68 insertions(+), 42 deletions(-) diff --git a/QEfficient/base/modeling_qeff.py b/QEfficient/base/modeling_qeff.py index af9cebf79..71fc40c7c 100644 --- a/QEfficient/base/modeling_qeff.py +++ b/QEfficient/base/modeling_qeff.py @@ -520,8 +520,8 @@ class FeatureNotAvailableError(Exception): exec_command = f'QAIC_COMPILER_OPTS_UNSUPPORTED="-loader-inline-all=0" {" ".join(command)}' raise FeatureNotAvailableError( - f"ONNX graph is exported with subfunctions, assert version of apps SDK should be used for compiling this model. \ - Run following command manually with assert compiler:\n{exec_command}" + "ONNX graph is exported with subfunctions, assert version of apps SDK should be used for compiling this model." 
+ + f"\nRun following command manually with assert compiler:\n{exec_command}" ) try: subprocess.run(command, capture_output=True, check=True) diff --git a/QEfficient/base/onnx_transforms.py b/QEfficient/base/onnx_transforms.py index 945850c50..081025b40 100644 --- a/QEfficient/base/onnx_transforms.py +++ b/QEfficient/base/onnx_transforms.py @@ -91,10 +91,10 @@ class CustomOpTransform(BaseOnnxTransform): "CtxScatterFunc3D": (CtxScatterFunc3D, CtxScatter3D), "CtxGatherFunc": (CtxGatherFunc, CtxGather), "CtxGatherFunc3D": (CtxGatherFunc3D, CtxGather3D), - "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB), "CtxScatterFuncCB3D": (CtxScatterFuncCB3D, CtxScatterCB3D), - "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB), "CtxGatherFuncCB3D": (CtxGatherFuncCB3D, CtxGatherCB3D), + "CtxScatterFuncCB": (CtxScatterFuncCB, CtxScatterCB), + "CtxGatherFuncCB": (CtxGatherFuncCB, CtxGatherCB), } @classmethod diff --git a/QEfficient/customop/ctx_scatter_gather_cb.py b/QEfficient/customop/ctx_scatter_gather_cb.py index 8a06bc2b1..c15b60810 100644 --- a/QEfficient/customop/ctx_scatter_gather_cb.py +++ b/QEfficient/customop/ctx_scatter_gather_cb.py @@ -126,6 +126,7 @@ class CtxGatherFuncCB(torch.autograd.Function): def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int): batch_indices = batch_index.view(-1, 1, 1) head_indices = torch.arange(data.shape[1]).view(1, -1, 1) + ctx_indices = torch.where(ctx_indices >= data.shape[2], 0, ctx_indices) return data[batch_indices, head_indices, ctx_indices] @staticmethod diff --git a/QEfficient/transformers/cache_utils.py b/QEfficient/transformers/cache_utils.py index fd6a81691..faadaba6b 100644 --- a/QEfficient/transformers/cache_utils.py +++ b/QEfficient/transformers/cache_utils.py @@ -46,6 +46,7 @@ def _get_invalid_idx_value(cls): """ if torch.onnx.is_in_onnx_export(): if cls.SUBFUNC_ENABLED: + # TODO: should not return 0 remove this if condition, it can hurt perf return 0 else: return torch.iinfo(torch.int32).max @@ -787,9 +788,22 @@ def full_cache_update_chunked( cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index") + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) + # Scatter + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids) + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ) + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ) + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] @@ -798,11 +812,13 @@ def full_cache_update_chunked( ctx_indices = torch.arange(ctx_len)[None, None, ...] 
gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit - - invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) - v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out @@ -815,26 +831,40 @@ def sliding_window_update_chunked( cache_kwargs: Optional[Dict[str, Any]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: position_ids = cache_kwargs.get("position_ids") + batch_index = cache_kwargs.get("batch_index") + invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() - self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) - self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) + if batch_index is not None: + if torch.onnx.is_in_onnx_export(): + scatter_position_ids = torch.where(position_ids < 0, torch.iinfo(torch.int32).max, position_ids) + self.key_cache[layer_idx] = CtxScatterFuncCB.apply( + self.key_cache[layer_idx], batch_index, scatter_position_ids, key_states + ) + self.value_cache[layer_idx] = CtxScatterFuncCB.apply( + self.value_cache[layer_idx], batch_index, scatter_position_ids, value_states + ) + else: + self.key_cache[layer_idx] = CtxScatterFunc.apply(self.key_cache[layer_idx], position_ids, key_states) + self.value_cache[layer_idx] = CtxScatterFunc.apply(self.value_cache[layer_idx], position_ids, value_states) k_out, v_out = self.key_cache[layer_idx], self.value_cache[layer_idx] sliding_window_len = cache_kwargs.get("sliding_window") + # Gather ctx_len = position_ids.shape[1] + sliding_window_len ctx_indices = torch.arange(ctx_len)[None, None, ...] 
- # positive_pos_ids = torch.where(position_ids<0, 0, position_ids) first_pos_idx = position_ids[0][0] add_idx = torch.where(first_pos_idx >= sliding_window_len, first_pos_idx - sliding_window_len, 0) ctx_indices += add_idx gather_limit = position_ids.max(1, keepdim=True).values.unsqueeze(1) invalid_mask = ctx_indices > gather_limit - - invalid_idx_value = InvalidIndexProvider._get_invalid_idx_value() ctx_indices = torch.where(invalid_mask, invalid_idx_value, ctx_indices) - k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) - v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) + if batch_index is not None: + k_out = CtxGatherFuncCB.apply(k_out, batch_index, ctx_indices, ctx_len) + v_out = CtxGatherFuncCB.apply(v_out, batch_index, ctx_indices, ctx_len) + else: + k_out = CtxGatherFunc.apply(k_out, ctx_indices, ctx_len) + v_out = CtxGatherFunc.apply(v_out, ctx_indices, ctx_len) v_out = torch.where(invalid_mask.unsqueeze(-1), torch.tensor(0.0, dtype=torch.float32), v_out) return k_out, v_out diff --git a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py index 309eab7af..95ef69f04 100644 --- a/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -92,8 +92,7 @@ def forward(self, hidden: torch.Tensor): down_out = (intermediate @ W_d) + b_d # [T, H] # Apply routing weights and accumulate - masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) - expert_out += masked_down + expert_out += down_out * routing_weight # original shape [B, S, H] return expert_out.view(B, S, H), router_logits @@ -148,8 +147,7 @@ def forward(self, hidden: torch.Tensor): down_out = (intermediate @ W_d) + b_d # [T, H] # Apply routing weights and accumulate - masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) - expert_out += masked_down + expert_out += down_out * routing_weight # original shape [B, S, H] return expert_out.view(B, S, H), router_logits @@ -221,8 +219,7 @@ def blocked_ffn_forward(self, hidden: torch.Tensor): down_out = torch.cat(outs, dim=0) # Apply routing weights and accumulate - masked_down = torch.where(routing_weight > 0, down_out * routing_weight, torch.zeros_like(expert_out)) - expert_out += masked_down + expert_out += down_out * routing_weight # original shape [B, S, H] return expert_out.view(B, S, H), router_logits @@ -1296,16 +1293,15 @@ def forward( router_logits=outputs.router_logits, ) - def get_pkv_dynamic_axes( - self, - retain_full_kv: Optional[bool] = False, - ): + def get_pkv_dynamic_axes(self, retain_full_kv: Optional[bool] = False, continuous_batching: Optional[bool] = False): pkv_dynamic_axes = [] for layer_type in self.config.layer_types: if layer_type == "sliding_attention" and not retain_full_kv: - pkv_dynamic_axes.append({0: "batch_size", 2: "sliding_window"}) + pkv_dynamic_axes.append( + {0: "full_batch_size" if continuous_batching else "batch_size", 2: "sliding_window"} + ) else: - pkv_dynamic_axes.append({0: "batch_size", 2: "ctx_len"}) + pkv_dynamic_axes.append({0: "full_batch_size" if continuous_batching else "batch_size", 2: "ctx_len"}) return pkv_dynamic_axes def get_specializations( diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 4c0a424b7..0347e2f45 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ 
b/QEfficient/transformers/models/modeling_auto.py @@ -2648,7 +2648,6 @@ def export( self.model.config, fbs if self.continuous_batching else bs, seq_len ) if prefill_only: - assert not self.continuous_batching, "prefill_only=True is not supported with continuous_batching=True" self.prefill(enable=True, enable_chunking=kwargs.get("enable_chunking", False)) self.hash_params.pop("retain_full_kv", None) seq_len = ( @@ -2722,7 +2721,8 @@ def export( pkv_dynamic_axes = ( self.model.get_pkv_dynamic_axes( retain_full_kv=kwargs.get("retain_full_kv", False) - or (prefill_only and kwargs.get("enable_chunking", False)) + or (prefill_only and kwargs.get("enable_chunking", False)), + continuous_batching=self.continuous_batching, ) if hasattr(self.model, "get_pkv_dynamic_axes") else pkv_dynamic_axes @@ -2734,7 +2734,6 @@ def export( ) for i in range(self.num_layers): - pkv_dynamic_axes[i][0] = "full_batch_size" if self.continuous_batching else "batch_size" for kv in ["key", "value"]: example_inputs["past_key_values"][i].append(torch.zeros(kv_cache_shape, dtype=torch.float32)) dynamic_axes[f"past_{kv}.{i}"] = pkv_dynamic_axes[i] @@ -3099,15 +3098,6 @@ def compile( if self.is_tlm: num_speculative_tokens = self.check_and_get_num_speculative_tokens(num_speculative_tokens, prefill_seq_len) - if self.continuous_batching and full_batch_size is None: - raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") - - if kv_cache_batch_size and not full_batch_size: - raise ValueError( - "KV caching requires continuous batching. Please set `full_batch_size` and " - "enable `continuous_batching=True` in `from_pretrained`." - ) - if ( self.model.qaic_config is not None and self.model.qaic_config.get("include_sampler", False) @@ -3117,7 +3107,9 @@ def compile( raise ValueError("Currently, sampler does not support `num_speculative_tokens` > 0.") if kv_cache_batch_size and prefill_only is not None and prefill_only: - logger.warning("kv_cache_batch_size will be ignored as prefill_only is set to True") + logger.warning( + "kv_cache_batch_size will be ignored as prefill_only is set to True unless this is GPTOSS model" + ) # Infer kv_cache_batch_size if not provided kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size @@ -3155,6 +3147,13 @@ def compile( ) if prefill_only is None or not prefill_only: + if self.continuous_batching and full_batch_size is None: + raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") + if kv_cache_batch_size and not full_batch_size: + raise ValueError( + "KV caching requires continuous batching. Please set `full_batch_size` and " + "enable `continuous_batching=True` in `from_pretrained`." 
+ ) if self.comp_ctx_lengths_decode is not None: # Adding elements from self.comp_ctx_lengths_decode to decode_specialization for i in range(0, len(self.comp_ctx_lengths_decode)): From 1b60a5f3f489ac1828a931b8e005de2726ca5a7c Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 9 Dec 2025 12:25:33 +0000 Subject: [PATCH 36/37] removed error Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/modeling_auto.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 0347e2f45..6e5c5aa53 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3061,6 +3061,14 @@ def compile( If `prefill_seq_len` is less than `num_speculative_tokens + 1` for TLM models. """ + if prefill_only is None or not prefill_only: + if self.continuous_batching and full_batch_size is None: + raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") + if kv_cache_batch_size and not full_batch_size: + raise ValueError( + "KV caching requires continuous batching. Please set `full_batch_size` and " + "enable `continuous_batching=True` in `from_pretrained`." + ) # if ccl_enabled is True read Compute-Context-Length lists if self.ccl_enabled: @@ -3147,13 +3155,6 @@ def compile( ) if prefill_only is None or not prefill_only: - if self.continuous_batching and full_batch_size is None: - raise TypeError("`full_batch_size` is required when `continuous_batching=True`.") - if kv_cache_batch_size and not full_batch_size: - raise ValueError( - "KV caching requires continuous batching. Please set `full_batch_size` and " - "enable `continuous_batching=True` in `from_pretrained`." - ) if self.comp_ctx_lengths_decode is not None: # Adding elements from self.comp_ctx_lengths_decode to decode_specialization for i in range(0, len(self.comp_ctx_lengths_decode)): From e8d11288c3ca26c3dda3bc12fa90e5b85701365e Mon Sep 17 00:00:00 2001 From: Onkar Chougule Date: Tue, 9 Dec 2025 12:39:24 +0000 Subject: [PATCH 37/37] added errors for prefill-only mode Signed-off-by: Onkar Chougule --- QEfficient/transformers/models/modeling_auto.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 6e5c5aa53..955abcd85 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3069,6 +3069,16 @@ def compile( "KV caching requires continuous batching. Please set `full_batch_size` and " "enable `continuous_batching=True` in `from_pretrained`." ) + else: + if self.continuous_batching: + if not enable_chunking: + raise NotImplementedError( + "Looks like you are trying to run prefix-caching without chunking, this feature is not available yet!" + ) + if not isinstance(kv_cache_batch_size, int): + raise ValueError( + "Please pass valid integer for kv_cache_batch_size as continuous_batching is enabled for prefill-only model" + ) # if ccl_enabled is True read Compute-Context-Length lists if self.ccl_enabled:
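Taken together, the prefix-caching patches gate the GPT-OSS prefill-only path behind chunked prefill and an explicit integer kv_cache_batch_size, while decode-phase compiles keep the existing full_batch_size requirement for continuous batching. A minimal usage sketch of those argument combinations, assuming only the compile() parameters shown in the diffs above; the checkpoint name, sequence lengths and batch sizes are placeholders, and actually producing QPCs still requires a Cloud AI 100 toolchain:

    from QEfficient import QEFFAutoModelForCausalLM

    model_id = "openai/gpt-oss-20b"          # placeholder checkpoint
    qeff_model = QEFFAutoModelForCausalLM.from_pretrained(model_id, continuous_batching=True)

    # Prefill-only compile with continuous batching (prefix caching).
    # Per the checks added above: enable_chunking must be True
    # (NotImplementedError otherwise) and kv_cache_batch_size must be an
    # integer (ValueError otherwise).
    prefill_qpc_path = qeff_model.compile(
        prefill_seq_len=128,
        ctx_len=384,
        num_cores=16,
        prefill_only=True,
        enable_chunking=True,
        kv_cache_batch_size=4,
        use_onnx_subfunctions=True,
    )

    # Decode-phase compile for the same model: with continuous batching,
    # full_batch_size must be provided up front, or the TypeError moved to the
    # top of compile() fires before any export work is done.
    decode_qpc_path = qeff_model.compile(
        prefill_seq_len=1,
        ctx_len=384,
        num_cores=16,
        full_batch_size=4,
        kv_cache_batch_size=4,
    )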