From 1097c8fbd5c87c8e9879441796f2c095066f45a6 Mon Sep 17 00:00:00 2001 From: Vahid Janfaza Date: Wed, 4 Mar 2026 16:35:44 -0800 Subject: [PATCH] Adding the support of CCL to the Prefilling of Disaggregated Serving Signed-off-by: Vahid Janfaza --- .../transformers/models/modeling_auto.py | 8 ++-- QEfficient/utils/check_ccl_specializations.py | 42 ++++++++++--------- QEfficient/utils/constants.py | 1 + 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/QEfficient/transformers/models/modeling_auto.py b/QEfficient/transformers/models/modeling_auto.py index 3a47aa5ff..8878e0a2b 100644 --- a/QEfficient/transformers/models/modeling_auto.py +++ b/QEfficient/transformers/models/modeling_auto.py @@ -3288,7 +3288,7 @@ def compile( if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None: logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).") self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations( - comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len + comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len, enable_chunking ) # For supporting VLLM and Disaggregated with CCL elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None: @@ -3308,7 +3308,7 @@ def compile( self.comp_ctx_lengths_decode = comp_ctx_lengths_decode self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations( - self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len + self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len, enable_chunking ) # --- Validation --- if prefill_only is not None and not isinstance(prefill_only, bool): @@ -3333,8 +3333,8 @@ def compile( ccl_lengths = self.comp_ctx_lengths_decode if prefill_seq_len == 1 else self.comp_ctx_lengths_prefill # Adding elements from self.comp_ctx_lengths_prefill to 
prefill_specialization for i in range(0, len(ccl_lengths)): - if prefill_only or enable_chunking: - raise NotImplementedError("prefill_only or enable_chunking is not supported with CCL") + # if prefill_only or enable_chunking: + # raise NotImplementedError("prefill_only or enable_chunking is not supported with CCL") specializations.append( self.build_prefill_specialization( prefill_seq_len=prefill_seq_len, diff --git a/QEfficient/utils/check_ccl_specializations.py b/QEfficient/utils/check_ccl_specializations.py index 368fde831..b2f0ff9e7 100644 --- a/QEfficient/utils/check_ccl_specializations.py +++ b/QEfficient/utils/check_ccl_specializations.py @@ -132,14 +132,15 @@ def validate_ccl_lists(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len): # Check CCL values are not negative and more than the CCL minimum context length = constants.CCL_MIN_CTX_LEN if ccl_prefill: ccl_prefill = [x if x >= constants.CCL_MIN_CTX_LEN else constants.CCL_MIN_CTX_LEN for x in ccl_prefill] + # Check the last element of ccl_prefill and ccl_decode to make sure it's not less than ctx_len + if ccl_prefill[-1] < ctx_len: + ccl_prefill.append(ctx_len) + if ccl_decode: ccl_decode = [x if x >= constants.CCL_MIN_CTX_LEN else constants.CCL_MIN_CTX_LEN for x in ccl_decode] - - # Check the last element of ccl_prefill and ccl_decode to make sure it's not less than ctx_len - if ccl_prefill[-1] < ctx_len - 1: - ccl_prefill.append(ctx_len) - if ccl_decode[-1] < ctx_len: - ccl_decode.append(ctx_len) + # Check the last element of ccl_prefill and ccl_decode to make sure it's not less than ctx_len + if ccl_decode[-1] < ctx_len: + ccl_decode.append(ctx_len) if prefill_seq_len == 1: # both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them. 
@@ -153,22 +154,25 @@ def validate_ccl_lists(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len): if ccl_decode: ccl_decode = sorted({min(x, ctx_len) for x in (ccl_decode)}) - # Handling the common values between ccl_prefill and ccl_decode. The elements of these two lists should be unique (COMPILER) - tmp_prefill = ccl_prefill - ccl_prefill = [] - for val in tmp_prefill: - while val in ccl_decode or val in ccl_prefill: - val -= 1 - if val < 0: - break # Prevent negative values - if val >= 0: - ccl_prefill.append(val) - ccl_prefill.sort() + # This check is needed for the disaggregated serving application, which generates two separate QPCs for prefilling and decoding. So, ccl_prefill will be None in the decode QPC and ccl_decode will be None in the prefill QPC + if ccl_prefill and ccl_decode: + # Handling the common values between ccl_prefill and ccl_decode. The elements of these two lists should be unique (COMPILER) + tmp_prefill = ccl_prefill + ccl_prefill = [] + for val in tmp_prefill: + while val in ccl_decode or val in ccl_prefill: + # If a value is common to ccl_prefill and ccl_decode, replace the ccl_prefill value with the closest smaller multiple of CCL_UNIQNE_STEP to avoid repetition while staying hardware- and compiler-efficient + val = (val - 1) // constants.CCL_UNIQNE_STEP * constants.CCL_UNIQNE_STEP + if val < 0: + break # Prevent negative values + if val >= 0: + ccl_prefill.append(val) + ccl_prefill.sort() return ccl_prefill, ccl_decode -def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len): +def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len, enable_chunking=False): """ This function evaluates the values of CCL lists based on three inputs: - ccl_prefill: optional [list] @@ -193,7 +197,7 @@ def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_le # One of ccl lists is [] or None -> replace it with [ctx_len] -> CCL lists have to have a value when CCL is 
enabled # Condition #3, #4, #5, and #6 - elif not ccl_prefill or not ccl_decode: + elif not ccl_prefill or not ccl_decode and not enable_chunking: # Initial setting and will be checked with edge cases later ccl_prefill = ccl_prefill if ccl_prefill else [ctx_len] ccl_decode = ccl_decode if ccl_decode else [ctx_len] diff --git a/QEfficient/utils/constants.py b/QEfficient/utils/constants.py index 7e6dd1cbb..3991ec6cd 100644 --- a/QEfficient/utils/constants.py +++ b/QEfficient/utils/constants.py @@ -187,6 +187,7 @@ def get_models_dir(): CCL_MAX_ELEMENTS_LISTS = 5 CCL_START_CTX_LEN = 4096 CCL_MIN_CTX_LEN = 1024 +CCL_UNIQNE_STEP = 32 # used for gpt-oss prefill-only model Q-blocking GPT_OSS_PREFILL_Q_BLOCK_SIZE = 256