76 changes: 62 additions & 14 deletions QEfficient/transformers/models/modeling_auto.py
@@ -937,11 +937,13 @@ def __init__(
self.model = model
self.config = model.config

self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)

self.vision_model = QEffVisionEncoderForTextImageToTextModel(model, **kwargs)
self.lang_model = QEffCausalLMForTextImageToTextModel(model, qaic_config=qaic_config, **kwargs)
self.continuous_batching = continuous_batching
self.ccl_enabled = False
if qaic_config:
self.ccl_enabled = qaic_config.get("ccl_enabled", False)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None
self.input_shapes, self.output_names = None, None

@property
@@ -985,6 +987,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: str, qaic_config: Option
logger.warning("Updating low_cpu_mem_usage=False")

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
return cls(
model,
@@ -1095,6 +1098,8 @@ def compile(
compile_dir: Optional[str] = None,
*,
prefill_seq_len: Optional[int] = None,
comp_ctx_lengths_prefill: Optional[List[int]] = None,
comp_ctx_lengths_decode: Optional[List[int]] = None,
ctx_len: Optional[int] = None,
batch_size: int = 1,
full_batch_size: Optional[int] = None,
@@ -1179,10 +1184,21 @@ def compile(

output_names = self.model.get_output_names(kv_offload=True)

# If ccl_enabled is True, read the Compute-Context-Length (CCL) lists
if self.ccl_enabled:
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
logger.warning(
"Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)

# For supporting VLLM and Disaggregated with CCL
if "comp_ctx_lengths_prefill" in compiler_options:
self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill")
self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode")
if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)

specializations, compiler_options = self.model.get_specializations(
batch_size=batch_size,
@@ -1630,7 +1646,6 @@ def __init__(
raise NotImplementedError("Continuous batching is not supported for image-text-to-text models yet.")
super().__init__(model, **kwargs)

self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)
self.model.qaic_config = qaic_config

# to handle internvl models
@@ -1644,6 +1659,10 @@ def __init__(
else:
self.model.config.use_cache = True
self.hash_params["qeff_auto_class"] = self.__class__.__name__
self.ccl_enabled = False
if qaic_config:
self.ccl_enabled = qaic_config.get("ccl_enabled", False)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None

if self.model.qaic_config is not None and self.model.qaic_config.get("num_kv_blocks", None) is not None:
BlockedKVAttentionTransform.apply(self.model, num_kv_blocks=self.model.qaic_config.get("num_kv_blocks"))
@@ -1683,6 +1702,7 @@ def from_pretrained(
logger.warning("Updating low_cpu_mem_usage=False")

kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

from transformers import AutoConfig

config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
@@ -1737,6 +1757,8 @@ def compile(
*,
prefill_seq_len: Optional[int] = None,
ctx_len: Optional[int] = None,
comp_ctx_lengths_prefill: Optional[List[int]] = None,
comp_ctx_lengths_decode: Optional[List[int]] = None,
batch_size: int = 1,
full_batch_size: Optional[int] = None,
kv_cache_batch_size: Optional[int] = None,
@@ -1806,10 +1828,21 @@ def compile(
kv_cache_batch_size = kv_cache_batch_size or full_batch_size or batch_size
output_names = self.model.get_output_names()

# If ccl_enabled is True, read the Compute-Context-Length (CCL) lists
if self.ccl_enabled:
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
logger.warning(
"Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)

# For supporting VLLM and Disaggregated with CCL
if "comp_ctx_lengths_prefill" in compiler_options:
self.comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill")
self.comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode")
if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)

# Get specializations from modelling file
# TODO: expose this via the auto class as well
@@ -2374,8 +2407,6 @@ def __init__(
# Set use_cache=True to get KV values as output during ONNX export
model.config.use_cache = True

self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(qaic_config)

super().__init__(model, qaic_config=qaic_config, **kwargs)
self.num_layers = model.config.num_hidden_layers
self.continuous_batching = continuous_batching
@@ -2384,6 +2415,10 @@ def __init__(
self.is_tlm = transformed

self.hash_params["qeff_auto_class"] = self.__class__.__name__
self.ccl_enabled = False
if qaic_config:
self.ccl_enabled = qaic_config.get("ccl_enabled", False)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = None, None

# ---Sampling---
# Note: SamplerTransform should be applied after all other transforms
@@ -2828,6 +2863,8 @@ def compile(
*,
prefill_seq_len: int = 32,
ctx_len: int = 128,
comp_ctx_lengths_prefill: Optional[List[int]] = None,
comp_ctx_lengths_decode: Optional[List[int]] = None,
batch_size: int = 1,
full_batch_size: Optional[int] = None,
kv_cache_batch_size: Optional[int] = None,
@@ -2919,10 +2956,18 @@ def compile(

"""

# If ccl_enabled is True, read the Compute-Context-Length (CCL) lists
if self.ccl_enabled:
if comp_ctx_lengths_prefill is None or comp_ctx_lengths_decode is None:
logger.warning(
"Please set comp_ctx_lengths_prefill and comp_ctx_lengths_decode with a proper list of context lengths. Using non-CCL default model."
)
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
comp_ctx_lengths_prefill, comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)

# For supporting VLLM and Disaggregated with CCL
if "comp_ctx_lengths_prefill" in compiler_options and "comp_ctx_lengths_decode" in compiler_options:
comp_ctx_lengths_prefill = compiler_options.pop("comp_ctx_lengths_prefill")
comp_ctx_lengths_decode = compiler_options.pop("comp_ctx_lengths_decode")
if comp_ctx_lengths_prefill is not None or comp_ctx_lengths_decode is not None:
if isinstance(comp_ctx_lengths_prefill, str):
import ast

@@ -2937,6 +2982,9 @@ def compile(
self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
self.comp_ctx_lengths_decode = comp_ctx_lengths_decode

self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode = process_ccl_specializations(
self.comp_ctx_lengths_prefill, self.comp_ctx_lengths_decode, ctx_len, prefill_seq_len
)
# --- Validation ---
if prefill_only is not None and not isinstance(prefill_only, bool):
raise TypeError("`prefill_only` must be a boolean.")
2 changes: 2 additions & 0 deletions QEfficient/transformers/spd/spd_transform_forward.py
@@ -76,6 +76,7 @@ def tlm_forward(
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
comp_ctx_lengths: Optional[torch.LongTensor] = None,
batch_index: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
@@ -123,6 +124,7 @@ def tlm_forward(
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
comp_ctx_lengths=comp_ctx_lengths,
batch_index=batch_index,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
9 changes: 1 addition & 8 deletions QEfficient/utils/check_ccl_specializations.py
@@ -6,14 +6,7 @@
# -----------------------------------------------------------------------------


def process_ccl_specializations(qaic_config):
if qaic_config is None:
return None, None
ccl_prefill = qaic_config.pop("comp_ctx_lengths_prefill", None)
ccl_decode = qaic_config.pop("comp_ctx_lengths_decode", None)
ctx_len = qaic_config.pop("ctx_len", None)
prefill_seq_len = qaic_config.pop("prefill_seq_len", 128)

def process_ccl_specializations(ccl_prefill, ccl_decode, ctx_len, prefill_seq_len):
if ccl_prefill is None or ccl_decode is None:
return None, None

50 changes: 50 additions & 0 deletions examples/performance/README.md
@@ -95,6 +95,56 @@ python on_device_sampling.py \
--top-p 0.89
```

### Compute-Context-Length

Compute-Context-Length (CCL) selects the compute context length dynamically during inference, so each context-length window runs with its best-performing specialization.

#### compute_context_length/basic_inference.py
Configure the CCL parameters: 1) `--ccl-enabled`: activates the CCL feature, 2) `--comp-ctx-lengths-prefill`: list of context lengths to use during prefill, and 3) `--comp-ctx-lengths-decode`: list of context lengths to use during decode. A Python-API sketch of the same options follows the command below.

**Usage for Text-only models:**
```bash
python compute_context_length/basic_inference.py \
--model-name meta-llama/Llama-3.1-8B \
--num-cores 16 \
--prefill-seq-len 32 \
--ctx-len 1024 \
--ccl-enabled \
--comp-ctx-lengths-prefill 500,1000 \
--comp-ctx-lengths-decode 512,1024
```
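
The command-line flags above map onto the QEfficient Python API that `basic_inference.py` wraps. A minimal sketch, assuming the `QEFFAutoModelForCausalLM` auto class; the prompt, core count, and `generate` call are illustrative:

```python
from transformers import AutoTokenizer

from QEfficient import QEFFAutoModelForCausalLM  # assumed import path

model_name = "meta-llama/Llama-3.1-8B"

# Enable the CCL feature when loading the model.
model = QEFFAutoModelForCausalLM.from_pretrained(
    model_name,
    qaic_config={"ccl_enabled": True},
)

# The CCL lists are passed at compile time; they are validated against
# ctx_len and prefill_seq_len by process_ccl_specializations.
model.compile(
    prefill_seq_len=32,
    ctx_len=1024,
    comp_ctx_lengths_prefill=[500, 1000],
    comp_ctx_lengths_decode=[512, 1024],
    num_cores=16,
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model.generate(prompts=["My name is"], tokenizer=tokenizer)
```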

**Usage for VLM models such as mllama and llava:**
```bash
python compute_context_length/vlm_inference.py \
--model-name meta-llama/Llama-3.2-11B-Vision-Instruct \
--hf-token "" \
--num-cores 16 \
--prefill-seq-len 32 \
--ctx-len 8192 \
--img-size 560 \
--ccl-enabled \
--comp-ctx-lengths-prefill 4096 \
--comp-ctx-lengths-decode 6144,8192
```

**Usage with other MoE and Multimodal models:**
For the other models available in the compute_context_length directory, such as gemma3, gpt_oss, granite_vision, internvl, llama4_cb, llama4_multi_image, llama4, mistral3, molmo, qwen2_5_vl, qwen2_5_vl_cb, and qwen3moe, run the corresponding inference script, changing only the model name and CCL configuration inside it (see the sketch after the commands below). Each script is invoked directly:
```bash
python compute_context_length/gemma3.py
python compute_context_length/gpt_oss.py
python compute_context_length/granite_vision.py
python compute_context_length/internvl.py
python compute_context_length/llama4_cb.py
python compute_context_length/llama4_multi_image.py
python compute_context_length/llama4.py
python compute_context_length/mistral3.py
python compute_context_length/molmo.py
python compute_context_length/qwen2_5_vl.py
python compute_context_length/qwen2_5_vl_cb.py
python compute_context_length/qwen3moe.py
```
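
All of these scripts follow the same pattern as `basic_inference.py`: `ccl_enabled` is set in `qaic_config` when the model is loaded, and the CCL lists are passed to `compile()`. A minimal multimodal sketch, assuming the `QEFFAutoModelForImageTextToText` auto class and its dual-QPC (`kv_offload=True`) path; the model name and window sizes mirror the VLM command above and are illustrative:

```python
from QEfficient import QEFFAutoModelForImageTextToText  # assumed import path

# Load with CCL enabled; kv_offload=True selects the dual-QPC
# (separate vision and language) path.
model = QEFFAutoModelForImageTextToText.from_pretrained(
    "meta-llama/Llama-3.2-11B-Vision-Instruct",
    kv_offload=True,
    qaic_config={"ccl_enabled": True},
)

# As with text-only models, the CCL lists go to compile() and are
# validated against ctx_len / prefill_seq_len.
model.compile(
    prefill_seq_len=32,
    ctx_len=8192,
    comp_ctx_lengths_prefill=[4096],
    comp_ctx_lengths_decode=[6144, 8192],
    num_cores=16,
)
```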

## Performance Tips

1. **Speculative Decoding**: Best for long-form generation where draft model is much faster than target
5 changes: 4 additions & 1 deletion examples/performance/compute_context_length/README.md
@@ -68,7 +68,7 @@ python vlm_inference.py \
Basic CCL usage with text-only language models.

**Supported Models:**
- Llama (3.2, 3.3)
- Llama (3.2, 3.3, SwiftKV)
- Gemma/Gemma-2
- Mistral
- Phi/Phi-3
@@ -77,6 +77,9 @@ Basic CCL usage with text-only language models.
- GPT-2, GPT-J
- CodeGen
- OLMo-2
- Mixtral
- Qwen2
- Falcon

**Command-Line Arguments:**
- `--model-name`: HuggingFace model ID (default: meta-llama/Llama-3.2-1B)
12 changes: 9 additions & 3 deletions examples/performance/compute_context_length/basic_inference.py
@@ -46,6 +46,11 @@ def main():
default=1024,
help="Maximum context length",
)
parser.add_argument(
"--ccl-enabled",
action="store_true",
help="Enable compute-context-length (CCL) feature",
)
parser.add_argument(
"--comp-ctx-lengths-prefill",
type=lambda x: [int(i) for i in x.split(",")],
@@ -113,9 +118,7 @@ def main():
args.model_name,
continuous_batching=args.continuous_batching,
qaic_config={
"comp_ctx_lengths_prefill": args.comp_ctx_lengths_prefill,
"comp_ctx_lengths_decode": args.comp_ctx_lengths_decode,
"ctx_len": args.ctx_len, # Required for CCL validation
"ccl_enabled": args.ccl_enabled,
},
)

@@ -132,6 +135,9 @@

if args.continuous_batching:
compile_kwargs["full_batch_size"] = args.full_batch_size
if args.ccl_enabled:
compile_kwargs["comp_ctx_lengths_prefill"] = args.comp_ctx_lengths_prefill
compile_kwargs["comp_ctx_lengths_decode"] = args.comp_ctx_lengths_decode

qpc_path = model.compile(**compile_kwargs)
print(f"Model compiled successfully to: {qpc_path}")