Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -3288,7 +3288,7 @@ def compile(
if comp_ctx_lengths_prefill is None and comp_ctx_lengths_decode is None:
logger.info("Auto-generating CCL-prefill and CCL-decode lists based on Context Length (CL).")
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len, enable_chunking
)
# For supporting VLLM and Disaggregated with CCL
elif comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
Expand All @@ -3308,7 +3308,7 @@ def compile(
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode

self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len = process_ccl_specializations(
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len, enable_chunking
)
# --- Validation ---
if prefill_only is not None and not isinstance(prefill_only, bool):
Expand All @@ -3333,8 +3333,8 @@ def compile(
ccl_lengths = self.comp_ctx_lengths_decode if prefill_seq_len == 1 else self.comp_ctx_lengths_prefill
# Adding elements from self.comp_ctx_lengths_prefill to prefill_specialization
for i in range(0, len(ccl_lengths)):
if prefill_only or enable_chunking:
raise NotImplementedError("prefill_only or enable_chunking is not supported with CCL")
# if prefill_only or enable_chunking:
# raise NotImplementedError("prefill_only or enable_chunking is not supported with CCL")
specializations.append(
self.build_prefill_specialization(
prefill_seq_len=prefill_seq_len,
Expand Down
42 changes: 23 additions & 19 deletions QEfficient/utils/check_ccl_specializations.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,14 +132,15 @@ def validate_ccl_lists(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
# Check CCL values are not negative and more than the CCL minimum context length = constants.CCL_MIN_CTX_LEN
if ccl_prefill:
ccl_prefill = [x if x >= constants.CCL_MIN_CTX_LEN else constants.CCL_MIN_CTX_LEN for x in ccl_prefill]
# Check the last element of ccl_prefill and ccl_decode to make sure it's not less than ctx_len
if ccl_prefill[-1] < ctx_len:
ccl_prefill.append(ctx_len)

if ccl_decode:
ccl_decode = [x if x >= constants.CCL_MIN_CTX_LEN else constants.CCL_MIN_CTX_LEN for x in ccl_decode]

# Check the last element of ccl_prefill and ccl_decode to make sure it's not less than ctx_len
if ccl_prefill[-1] < ctx_len - 1:
ccl_prefill.append(ctx_len)
if ccl_decode[-1] < ctx_len:
ccl_decode.append(ctx_len)
# Check the last element of ccl_prefill and ccl_decode to make sure it's not less than ctx_len
if ccl_decode[-1] < ctx_len:
ccl_decode.append(ctx_len)

if prefill_seq_len == 1:
# both prefill and decode ccl can share the same specializations since prefill_seq_len=1. So, a sorted union of both lists can be used for both of them.
Expand All @@ -153,22 +154,25 @@ def validate_ccl_lists(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
if ccl_decode:
ccl_decode = sorted({min(x, ctx_len) for x in (ccl_decode)})

# Handling the common values between ccl_prefill and ccl_decode. The elements of these two lists should be unique (COMPILER)
tmp_prefill = ccl_prefill
ccl_prefill = []
for val in tmp_prefill:
while val in ccl_decode or val in ccl_prefill:
val -= 1
if val < 0:
break # Prevent negative values
if val >= 0:
ccl_prefill.append(val)
ccl_prefill.sort()
# This check is needed for disaggregated serving, which generates two separate QPCs for prefill and decode: ccl_prefill will be None in the decode QPC and ccl_decode will be None in the prefill QPC.
if ccl_prefill and ccl_decode:
# Handling the common values between ccl_prefill and ccl_decode. The elements of these two lists should be unique (COMPILER)
tmp_prefill = ccl_prefill
ccl_prefill = []
for val in tmp_prefill:
while val in ccl_decode or val in ccl_prefill:
# When a value appears in both ccl_prefill and ccl_decode (or is already in ccl_prefill), replace it with the next lower multiple of CCL_UNIQNE_STEP; this avoids duplicates while keeping the value hardware- and compiler-efficient.
val = (val - 1) // constants.CCL_UNIQNE_STEP * constants.CCL_UNIQNE_STEP
if val < 0:
break # Prevent negative values
if val >= 0:
ccl_prefill.append(val)
ccl_prefill.sort()

return ccl_prefill, ccl_decode


def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len, enable_chunking=False):
"""
This function evaluates the values of CCL lists based on three inputs:
- ccl_prefill: optional [list]
Expand All @@ -193,7 +197,7 @@ def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_le

# One of ccl lists is [] or None -> replace it with [ctx_len] -> CCL lists have to have a value when CCL is enabled
# Condition #3, #4, #5, and #6
elif not ccl_prefill or not ccl_decode:
elif not ccl_prefill or not ccl_decode and not enable_chunking:
# Initial setting and will be checked with edge cases later
ccl_prefill = ccl_prefill if ccl_prefill else [ctx_len]
ccl_decode = ccl_decode if ccl_decode else [ctx_len]
Expand Down
1 change: 1 addition & 0 deletions QEfficient/utils/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ def get_models_dir():
CCL_MAX_ELEMENTS_LISTS = 5
CCL_START_CTX_LEN = 4096
CCL_MIN_CTX_LEN = 1024
CCL_UNIQNE_STEP = 32

# used for gpt-oss prefill-only model Q-blocking
GPT_OSS_PREFILL_Q_BLOCK_SIZE = 256
Expand Down