Skip to content

Commit 0912c39

Browse files
committed
Adding ccl_enabled flag during model loading and passing CCL lists during compilation process
Signed-off-by: Vahid Janfaza <vjanfaza@qti.qualcomm.com>
1 parent 30c334b commit 0912c39

16 files changed

+263
-166
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 67 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -909,7 +909,7 @@ def __init__(
909909
self,
910910
model: nn.Module,
911911
continuous_batching: bool = False,
912-
qaic_config: Optional[dict] = None,
912+
ccl_enabled: bool = False,
913913
**kwargs,
914914
):
915915
"""
@@ -932,11 +932,10 @@ def __init__(
932932
self.model = model
933933
self.config = model.config
934934

935-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)
936-
937935
self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
938936
self.lang_model = QEffCausalLMForTextImageToTextModel(model, **kwargs)
939937
self.continuous_batching = continuous_batching
938+
self.ccl_enabled = ccl_enabled
940939
self.input_shapes, self.output_names = None, None
941940

942941
@property
@@ -955,7 +954,7 @@ def model_name(self) -> str:
955954
return mname
956955

957956
@classmethod
958-
def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Optional[dict] = None, **kwargs):
957+
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
959958
"""
960959
Load a QEfficient multimodal model for dual QPC from a pretrained HuggingFace model or local path.
961960
@@ -980,11 +979,13 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option
980979
logger.warning("Updating low_cpu_mem_usage=False")
981980

982981
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
982+
ccl_enabled = kwargs.pop("ccl_enabled", None)
983+
983984
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
984985
return cls(
985986
model,
986987
pretrained_model_name_or_path=pretrained_model_name_or_path,
987-
qaic_config=qaic_config,
988+
ccl_enabled=ccl_enabled,
988989
**kwargs,
989990
)
990991

@@ -1090,6 +1091,8 @@ def compile(
10901091
compile_dir: Optional[str] = None,
10911092
*,
10921093
prefill_seq_len: Optional[int] = None,
1094+
comp_ctx_lengths_prefill: Optional[List[int]] = None,
1095+
comp_ctx_lengths_decode: Optional[List[int]] = None,
10931096
ctx_len: Optional[int] = None,
10941097
batch_size: int = 1,
10951098
full_batch_size: Optional[int] = None,
@@ -1174,10 +1177,21 @@ def compile(
11741177

11751178
output_names = self.model.get_output_names(kv_offload=True)
11761179

1180+
# if ccl_enabled is True read Compute-Context-Length lists
1181+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
1182+
if self.ccl_enabled:
1183+
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
1184+
logger.warning(
1185+
"Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
1186+
)
1187+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
1188+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
1189+
)
1190+
11771191
# For supporting VLLM and Disaggregated with CCL
1178-
if "comp_ctx_lengths_prefill" in compiler_options:
1179-
self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill")
1180-
self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode")
1192+
if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
1193+
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
1194+
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
11811195

11821196
specializations, compiler_options = self.model.get_specializations(
11831197
batch_size=batch_size,
@@ -1600,7 +1614,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal
16001614
def __init__(
16011615
self,
16021616
model: nn.Module,
1603-
qaic_config: Optional[dict] = None,
1617+
ccl_enabled: bool = False,
16041618
**kwargs,
16051619
):
16061620
"""
@@ -1622,8 +1636,6 @@ def __init__(
16221636
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
16231637
super().__init__(model, **kwargs)
16241638

1625-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)
1626-
16271639
# to handle internvl models
16281640
if hasattr(self.model.config, "llm_config") and hasattr(self.model.config, "vision_config"):
16291641
self.model.config.llm_config.use_cache = True
@@ -1635,12 +1647,12 @@ def __init__(
16351647
else:
16361648
self.model.config.use_cache = True
16371649
self.hash_params["qeff_auto_class"] = self.__class__.__name__
1650+
self.ccl_enabled = ccl_enabled
16381651

16391652
@classmethod
16401653
def from_pretrained(
16411654
cls,
16421655
pretrained_model_name_or_path,
1643-
qaic_config: Optional[dict] = None,
16441656
*args,
16451657
**kwargs,
16461658
):
@@ -1671,6 +1683,8 @@ def from_pretrained(
16711683
logger.warning("Updating low_cpu_mem_usage=False")
16721684

16731685
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
1686+
ccl_enabled = kwargs.pop("ccl_enabled", None)
1687+
16741688
from transformers import AutoConfig
16751689

16761690
config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
@@ -1681,7 +1695,7 @@ def from_pretrained(
16811695
return cls(
16821696
model,
16831697
pretrained_model_name_or_path=pretrained_model_name_or_path,
1684-
qaic_config=qaic_config,
1698+
ccl_enabled=ccl_enabled,
16851699
**kwargs,
16861700
)
16871701

@@ -1725,6 +1739,8 @@ def compile(
17251739
*,
17261740
prefill_seq_len: Optional[int] = None,
17271741
ctx_len: Optional[int] = None,
1742+
comp_ctx_lengths_prefill: Optional[List[int]] = None,
1743+
comp_ctx_lengths_decode: Optional[List[int]] = None,
17281744
batch_size: int = 1,
17291745
full_batch_size: Optional[int] = None,
17301746
kv_cache_batch_size: Optional[int] = None,
@@ -1794,10 +1810,21 @@ def compile(
17941810
kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
17951811
output_names = self.model.get_output_names()
17961812

1813+
# if ccl_enabled is True read Compute-Context-Length lists
1814+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
1815+
if self.ccl_enabled:
1816+
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
1817+
logger.warning(
1818+
"Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
1819+
)
1820+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
1821+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
1822+
)
1823+
17971824
# For supporting VLLM and Disaggregated with CCL
1798-
if "comp_ctx_lengths_prefill" in compiler_options:
1799-
self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill")
1800-
self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode")
1825+
if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
1826+
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
1827+
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
18011828

18021829
# Get specializations from modelling file
18031830
# TODO: expose this via the auto class as well
@@ -2180,7 +2207,7 @@ def __new__(
21802207
model: nn.Module,
21812208
kv_offload: Optional[bool] = True,
21822209
continuous_batching: bool = False,
2183-
qaic_config: Optional[dict] = None,
2210+
ccl_enabled: bool = False,
21842211
**kwargs,
21852212
):
21862213
"""
@@ -2204,10 +2231,10 @@ def __new__(
22042231
"""
22052232
if kv_offload:
22062233
return _QEffAutoModelForImageTextToTextDualQPC(
2207-
model, continuous_batching, qaic_config=qaic_config, **kwargs
2234+
model, continuous_batching, ccl_enabled=ccl_enabled, **kwargs
22082235
)
22092236
else:
2210-
return _QEFFAutoModelForImageTextToTextSingleQPC(model, qaic_config=qaic_config, **kwargs)
2237+
return _QEFFAutoModelForImageTextToTextSingleQPC(model, ccl_enabled=ccl_enabled, **kwargs)
22112238

22122239
@classmethod
22132240
@with_replaced_quantizers
@@ -2257,14 +2284,15 @@ def from_pretrained(
22572284
logger.warning("Updating low_cpu_mem_usage=False")
22582285

22592286
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
2287+
ccl_enabled = kwargs.pop("ccl_enabled", None)
22602288

22612289
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
22622290
return cls(
22632291
model,
22642292
kv_offload=kv_offload,
22652293
continuous_batching=continuous_batching,
22662294
pretrained_model_name_or_path=pretrained_model_name_or_path,
2267-
qaic_config=qaic_config,
2295+
ccl_enabled=ccl_enabled,
22682296
**kwargs,
22692297
)
22702298

@@ -2317,6 +2345,7 @@ def __init__(
23172345
model: nn.Module,
23182346
continuous_batching: bool = False,
23192347
qaic_config: Optional[dict] = None,
2348+
ccl_enabled: bool = False,
23202349
**kwargs,
23212350
):
23222351
"""
@@ -2363,8 +2392,6 @@ def __init__(
23632392
# Set use_cache=True to get KV values as output during ONNX export
23642393
model.config.use_cache = True
23652394

2366-
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)
2367-
23682395
super().__init__(model, qaic_config=qaic_config, **kwargs)
23692396
self.num_layers = model.config.num_hidden_layers
23702397
self.continuous_batching = continuous_batching
@@ -2373,6 +2400,7 @@ def __init__(
23732400
self.is_tlm = transformed
23742401

23752402
self.hash_params["qeff_auto_class"] = self.__class__.__name__
2403+
self.ccl_enabled = ccl_enabled
23762404

23772405
# ---Sampling---
23782406
# Note: SamplerTransform should be applied after all other transforms
@@ -2465,6 +2493,7 @@ def from_pretrained(
24652493
logger.warning("Updating low_cpu_mem_usage=False")
24662494

24672495
kv_offload = kwargs.pop("kv_offload", None)
2496+
ccl_enabled = kwargs.pop("ccl_enabled", None)
24682497

24692498
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})
24702499
model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
@@ -2478,14 +2507,15 @@ def from_pretrained(
24782507
model,
24792508
kv_offload=kv_offload,
24802509
pretrained_model_name_or_path=pretrained_model_name_or_path,
2481-
qaic_config=qaic_config,
2510+
ccl_enabled=ccl_enabled,
24822511
**kwargs,
24832512
)
24842513
return cls(
24852514
model,
24862515
continuous_batching=continuous_batching,
24872516
qaic_config=qaic_config,
24882517
pretrained_model_name_or_path=pretrained_model_name_or_path,
2518+
ccl_enabled=ccl_enabled,
24892519
**kwargs,
24902520
)
24912521

@@ -2814,6 +2844,8 @@ def compile(
28142844
*,
28152845
prefill_seq_len: int = 32,
28162846
ctx_len: int = 128,
2847+
comp_ctx_lengths_prefill: Optional[List[int]] = None,
2848+
comp_ctx_lengths_decode: Optional[List[int]] = None,
28172849
batch_size: int = 1,
28182850
full_batch_size: Optional[int] = None,
28192851
kv_cache_batch_size: Optional[int] = None,
@@ -2905,10 +2937,19 @@ def compile(
29052937
29062938
"""
29072939

2940+
# if ccl_enabled is True read Compute-Context-Length lists
2941+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
2942+
if self.ccl_enabled:
2943+
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
2944+
logger.warning(
2945+
"Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
2946+
)
2947+
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
2948+
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
2949+
)
2950+
29082951
# For supporting VLLM and Disaggregated with CCL
2909-
if "comp_ctx_lengths_prefill" in compiler_options and "comp_ctx_lengths_decode" in compiler_options:
2910-
comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill")
2911-
comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode")
2952+
if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
29122953
if isinstance(comp_ctx_lengths_prefill, str):
29132954
import ast
29142955

QEfficient/utils/check_ccl_specializations.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,7 @@
66
# -----------------------------------------------------------------------------
77

88

9-
def process_ccl_specializations(qaic_config):
10-
if qaic_config is None:
11-
return None, None
12-
ccl_prefill = qaic_config.pop("comp_ctx_lengths_prefill", None)
13-
ccl_decode = qaic_config.pop("comp_ctx_lengths_decode", None)
14-
ctx_len = qaic_config.pop("ctx_len", None)
15-
prefill_seq_len = qaic_config.pop("prefill_seq_len", 128)
16-
9+
def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
1710
if ccl_prefill is None or ccl_decode is None:
1811
return None, None
1912

examples/ccl_gpt_oss.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,26 +11,28 @@
1111

1212
model_id = "openai/gpt-oss-20b" # weights are not required to convert to fp32
1313

14+
## Activate Compute-Context-Length (CCL) feature by setting ccl_enabled=True when loading the model with from_pretrained().
15+
## Use the optional comp_ctx_lengths argument to provide two lists of context lengths for the prefilling and decoding processes. If comp_ctx_lengths=None, the model will run with its default context length.
16+
## - The first list, comp_ctx_lengths_prefill, defines the compute-context-length values for the prefilling process.
17+
## -- The process starts with the first value in the list and gradually increases the context length based on the position_id of the current prompt chunk.
18+
## - The second list, comp_ctx_lengths_decode, defines the compute-context-length values for the decoding process.
19+
## -- During decoding, the model selects an appropriate context length from the list based on the input prompt length and cache index.
20+
## -- It starts from the correct value in the list and increases the context length dynamically when the cache index exceeds the current threshold.
21+
1422
ctx_len = 4096
1523
# In moe models like gpt-oss, since prefill_seq_len=1 both comp_ctx_lengths_prefill and comp_ctx_lengths_decode can share similar lists.
1624
# Set the list of ccl during prefilling process
17-
comp_ctx_lengths_prefill = [512, ctx_len]
25+
comp_ctx_lengths_prefill = [512, ctx_len] #None #
1826
# Set the list of ccl during decoding process
19-
comp_ctx_lengths_decode = [512, ctx_len]
27+
comp_ctx_lengths_decode = [512, ctx_len] #None #
2028

2129

2230
qeff_model = QEFFAutoModelForCausalLM.from_pretrained(
2331
model_id,
24-
qaic_config={
25-
"comp_ctx_lengths_prefill": comp_ctx_lengths_prefill,
26-
"comp_ctx_lengths_decode": comp_ctx_lengths_decode,
27-
"ctx_len": ctx_len,
28-
"prefill_seq_len": 1, # Passing prefill_seq_len is mandatory for CCL goal in moe models. Currently we can get best perf using PL=1.
29-
},
32+
ccl_enabled=True,
3033
)
3134
tokenizer = AutoTokenizer.from_pretrained(model_id)
3235

33-
onnx_model_path = qeff_model.export()
3436
qpc_path = qeff_model.compile(
3537
prefill_seq_len=1, # Currently we can get best perf using PL=1 i.e. decode-only model, prefill optimizations are being worked on.
3638
ctx_len=ctx_len,
@@ -41,6 +43,8 @@
4143
mos=1,
4244
aic_enable_depth_first=True,
4345
num_speculative_tokens=None,
46+
comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
47+
comp_ctx_lengths_decode=comp_ctx_lengths_decode,
4448
)
4549
print(f"qpc path is {qpc_path}")
4650
streamer = TextStreamer(tokenizer)

0 commit comments

Comments
 (0)