
Commit d02d04d

Merge remote-tracking branch 'origin/main' into HEAD
2 parents 30d6061 + f4ff803

61 files changed, +3195 -314 lines

QEfficient/cloud/infer.py

Lines changed: 12 additions & 0 deletions
@@ -340,6 +340,18 @@ def main(
         "--prompt-len", "--prompt_len", default=32, type=int, help="Sequence length for text generation."
     )
     parser.add_argument("--ctx-len", "--ctx_len", default=128, type=int, help="Context length for text generation.")
+    parser.add_argument(
+        "--comp-ctx-lengths-prefill",
+        type=lambda comp_ctx_lengths_prefill: [int(x) for x in comp_ctx_lengths_prefill.split(",")],
+        default=[512],
+        help="Define the CCL list in CSV format (e.g., --comp-ctx-lengths-prefill 512,1024,2048).",
+    )
+    parser.add_argument(
+        "--comp-ctx-lengths-decode",
+        type=lambda comp_ctx_lengths_decode: [int(x) for x in comp_ctx_lengths_decode.split(",")],
+        default=[2048],
+        help="Define the CCL list in CSV format (e.g., --comp-ctx-lengths-decode 512,1024,2048).",
+    )
     parser.add_argument(
         "--mxfp6",
         "--mxfp6_matmul",

QEfficient/customop/ctx_scatter_gather.py

Lines changed: 11 additions & 5 deletions
@@ -115,8 +115,14 @@ def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:


 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
-def CtxGather(data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32) -> onnxscript.FLOAT:
-    ctx_indices = ops.Expand(ctx_indices, ops.Slice(ops.Shape(data), starts=[0], ends=[3], axes=[0]))
+def CtxGather(
+    data: onnxscript.FLOAT, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
+) -> onnxscript.FLOAT:
+    # Create a shape tensor based on comp_ctx_len
+    shape_tensor = ops.Concat(ops.Shape(data)[:2], ops.Reshape(comp_ctx_len, [1]), axis=0)
+
+    # Directly use the shape tensor without validation
+    ctx_indices = ops.Expand(ctx_indices, shape_tensor)
     ctx_indices = ops.Unsqueeze(ctx_indices, [-1])
     return ops.GatherND(data, ctx_indices, batch_dims=2)

@@ -127,7 +133,7 @@ class CtxGatherFunc(torch.autograd.Function):
     """

     @staticmethod
-    def forward(data: torch.Tensor, ctx_indices: torch.Tensor):
+    def forward(data: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
         return data[batch_indices, head_indices, ctx_indices]

@@ -137,5 +143,5 @@ def setup_context(ctx, inputs, outputs):
         pass

     @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxGather, data, ctx_indices).setTypeAs(data)
+    def symbolic(g: torch.Graph, data: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int) -> torch.Value:
+        return g.onnxscript_op(CtxGather, data, ctx_indices, comp_ctx_len).setTypeAs(data)
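
In eager mode the forward path is unchanged; comp_ctx_len only pins the gather width in the exported ONNX graph. A hedged sketch of the effect, with invented shapes: only the first comp_ctx_len cache positions feed the downstream attention.

import torch

# KV cache slice: <batch, heads, ctx_len, head_dim>  (sizes invented for illustration)
data = torch.randn(2, 4, 128, 64)
comp_ctx_len = 32  # only the first 32 cache positions participate

ctx_indices = torch.arange(comp_ctx_len).view(1, 1, -1).expand(2, 4, -1)
batch_indices = torch.arange(data.shape[0]).view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)

out = data[batch_indices, head_indices, ctx_indices]
assert out.shape == (2, 4, 32, 64)  # attention now runs over 32, not 128, positions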

QEfficient/customop/ctx_scatter_gather_cb.py

Lines changed: 12 additions & 6 deletions
@@ -97,16 +97,20 @@ def symbolic(


 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
 def CtxGatherCB(
-    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32
+    data: onnxscript.FLOAT, batch_index: onnxscript.INT32, ctx_indices: onnxscript.INT32, comp_ctx_len: onnxscript.INT32
 ) -> onnxscript.FLOAT:
     batch_size = ops.Gather(ops.Shape(batch_index), [0])
     num_heads = ops.Gather(ops.Shape(data), [1])
-    ctx_len = ops.Gather(ops.Shape(data), [2])
+    # Use the compute context length (CCL) instead of the full context length, so both
+    # the gather and the attention computation that follows are bounded by CCL.
+    ctx_len = ops.Reshape(comp_ctx_len, [1])

     # Expanded shape to create indices
     zero = ops.Constant(value_ints=[0])
     one = ops.Constant(value_ints=[1])
-    exp_shape = ops.Concat(batch_size, num_heads, ctx_len, one, axis=0)
+    exp_shape = ops.Concat(
+        ops.Reshape(batch_size, [1]), ops.Reshape(num_heads, [1]), ops.Reshape(ctx_len, [1]), one, axis=0
+    )

     # Create indices
     batch_idx = ops.Expand(ops.Unsqueeze(batch_index, [2, 3]), exp_shape)

@@ -119,7 +123,7 @@ def CtxGatherCB(

 class CtxGatherFuncCB(torch.autograd.Function):
     @staticmethod
-    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor):
+    def forward(data: torch.Tensor, batch_index: torch.Tensor, ctx_indices: torch.Tensor, comp_ctx_len: int):
         batch_indices = batch_index.view(-1, 1, 1)
         head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
         return data[batch_indices, head_indices, ctx_indices]

@@ -129,8 +133,10 @@ def setup_context(ctx, inputs, outputs):
         pass

     @staticmethod
-    def symbolic(g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value) -> torch.Value:
-        return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices).setTypeAs(data)
+    def symbolic(
+        g: torch.Graph, data: torch.Value, batch_index: torch.Value, ctx_indices: torch.Value, comp_ctx_len: int
+    ) -> torch.Value:
+        return g.onnxscript_op(CtxGatherCB, data, batch_index, ctx_indices, comp_ctx_len).setTypeAs(data)


 @onnxscript.script(onnxscript.values.Opset("com.qualcomm.cloud", 1))
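
A note on the Reshape calls above: ONNX Concat requires all inputs to have the same rank, and comp_ctx_len arrives as a graph input whose rank is not guaranteed, so every operand is normalized to rank 1 before concatenation (this rationale is inferred, not stated in the commit). A quick numpy check of the resulting expand shape, with invented sizes:

import numpy as np

batch_size, num_heads, comp_ctx_len = 2, 4, 32

# Mirrors exp_shape = Concat(batch, heads, ccl, 1) after each piece is reshaped to rank 1.
exp_shape = np.concatenate(
    [np.array([batch_size]), np.array([num_heads]), np.array([comp_ctx_len]), np.array([1])]
)
assert exp_shape.tolist() == [2, 4, 32, 1]  # shape each index tensor is expanded to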

QEfficient/exporter/export_utils.py

Lines changed: 2 additions & 1 deletion
@@ -18,6 +18,7 @@
 from onnx import external_data_helper

 from QEfficient.base.onnx_transforms import FP16ClipTransform
+from QEfficient.utils import constants


 def export_onnx(

@@ -97,7 +98,7 @@ def export_onnx(
             input_names=input_names,
             output_names=output_names,
             dynamic_axes=dynamic_axes,
-            opset_version=13,
+            opset_version=constants.ONNX_EXPORT_OPSET,
             custom_opsets={"com.qti.aisw.onnx": 1},
         )
     except Exception as e:
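
Centralizing the opset makes future bumps one-line changes. The diff does not show the constant's definition; presumably QEfficient/utils/constants.py carries something along these lines (value assumed, not confirmed by this diff):

# QEfficient/utils/constants.py -- assumed definition, not shown in this diff
ONNX_EXPORT_OPSET = 13  # single source of truth for torch.onnx.export's opset_version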

QEfficient/finetune/utils/train_utils.py

Lines changed: 5 additions & 4 deletions
@@ -123,11 +123,12 @@ def train(
                 break

         if train_config.use_peft and train_config.from_peft_checkpoint:
+            path = train_config.from_peft_checkpoint.rstrip("/")
             try:
-                intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-2].split("_")[-1]) - 1
-                intermediate_step = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1])
+                intermediate_epoch = int(path.split("/")[-2].split("_")[-1]) - 1
+                intermediate_step = int(path.split("/")[-1].split("_")[-1])
             except (IndexError, ValueError):
-                intermediate_epoch = int(train_config.from_peft_checkpoint.split("/")[-1].split("_")[-1]) - 1
+                intermediate_epoch = int(path.split("/")[-1].split("_")[-1]) - 1
                 intermediate_step = 0

         if epoch < intermediate_epoch:

@@ -374,7 +375,7 @@ def train(
                 eval_step_metric,
                 eval_metric,
             )
-    avg_epoch_time = sum(epoch_times) / len(epoch_times)
+    avg_epoch_time = sum(epoch_times) / len(epoch_times) if len(epoch_times) > 0 else 0
     avg_checkpoint_time = sum(checkpoint_times) / len(checkpoint_times) if len(checkpoint_times) > 0 else 0

     results["last_epoch_train_loss"] = train_epoch_loss.cpu()

QEfficient/generation/text_generation_inference.py

Lines changed: 76 additions & 0 deletions
@@ -318,6 +318,8 @@ def cloud_ai_100_exec_kv(
     prompts_txt_file_path: Optional[str] = None,
     device_id: Optional[List[int]] = None,
     generation_len: Optional[int] = None,
+    comp_ctx_lengths_prefill: Optional[List[int]] = None,
+    comp_ctx_lengths_decode: Optional[List[int]] = None,
     enable_debug_logs: bool = False,
     stream: bool = True,
     write_io_dir: Optional[str] = None,

@@ -384,6 +386,8 @@
         qpc_path=qpc_path,
         device_id=device_id,
         ctx_len=ctx_len,
+        comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+        comp_ctx_lengths_decode=comp_ctx_lengths_decode,
         enable_debug_logs=enable_debug_logs,
         write_io_dir=write_io_dir,
         full_batch_size=full_batch_size,
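
For context, a hedged sketch of how the two new keyword arguments would be supplied through the public entry point; the tokenizer, qpc path, and prompt values are placeholders assumed from the surrounding API, and actual execution requires a Cloud AI 100 device:

from transformers import AutoTokenizer
from QEfficient.generation.text_generation_inference import cloud_ai_100_exec_kv

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
exec_info = cloud_ai_100_exec_kv(
    tokenizer,
    qpc_path="path/to/qpcs",               # compiled QPC directory (placeholder)
    prompt=["My name is"],                 # `prompt` kwarg assumed from the public API
    ctx_len=2048,
    generation_len=64,
    comp_ctx_lengths_prefill=[512, 1024],  # CCL ladder used during prefill
    comp_ctx_lengths_decode=[1024, 2048],  # CCL ladder used during decode
)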
@@ -430,6 +434,8 @@ def __init__(
         qpc_path: str,
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths_prefill: Optional[List[int]] = None,
+        comp_ctx_lengths_decode: Optional[List[int]] = None,
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,

@@ -440,6 +446,8 @@
         activate: bool = True,
     ) -> None:
         self._ctx_len = ctx_len
+        self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
+        self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
         self._write_io_dir = write_io_dir
         self.is_tlm = is_tlm
         self.return_pdfs = return_pdfs
@@ -802,7 +810,17 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
             batch_lora_ids = [self._prompt_to_lora_id_mapping_prefill.popleft() for i in range(self.batch_size)]
             inputs["lora_ids"] = np.array(batch_lora_ids, dtype=np.int64).reshape(self.batch_size, 1)

+        if self.comp_ctx_lengths_prefill is not None:
+            self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
+            prefill_ccl_id = 0
+            inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
         for i in range(num_chunks):
+            if self.comp_ctx_lengths_prefill is not None:
+                if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]:
+                    prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1)
+                    inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
             chunk_inputs = inputs.copy()
             chunk_inputs["input_ids"] = inputs["input_ids"][
                 :, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len
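
To make the chunk-to-bucket mapping concrete, a walkthrough with invented sizes: a prefill chunk advances to the next CCL bucket once the tokens processed so far exceed the current bucket.

# Illustrative only: prefill_seq_len and the bucket list are made-up values.
prefill_seq_len = 128
comp_ctx_lengths_prefill = [512, 1024]

prefill_ccl_id = 0
buckets_used = []
for i in range(8):  # 8 chunks = 1024 prompt tokens
    if (i + 1) * prefill_seq_len > comp_ctx_lengths_prefill[prefill_ccl_id]:
        prefill_ccl_id = min(prefill_ccl_id + 1, len(comp_ctx_lengths_prefill) - 1)
    buckets_used.append(comp_ctx_lengths_prefill[prefill_ccl_id])

# Chunks 0-3 (tokens 1-512) prefill with CCL 512; chunks 4-7 switch to 1024.
assert buckets_used == [512, 512, 512, 512, 1024, 1024, 1024, 1024]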
@@ -822,6 +840,19 @@ def run_prefill(self, prompt, generation_len, prefill_logit_bs=1, decode_batch_id=None):
             generation_len,
         )

+    def initialize_ccl(self, decode_inputs):
+        self.list_of_comp_ctx_lengths_decode = [np.zeros(length) for length in self.comp_ctx_lengths_decode]
+        max_ccl_id = len(self.comp_ctx_lengths_decode) - 1
+        max_position_id = np.max(decode_inputs["position_ids"])
+        ccl_id_initial = 0
+        ccl_id = ccl_id_initial
+        for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
+            if max_position_id < self.comp_ctx_lengths_decode[i]:
+                ccl_id = i
+                break
+
+        return ccl_id, max_ccl_id
+
     def run_continuous_batching_decode(self, prompt_queue, generation_len):
         """
         Runs continuous batching decode for the given prompt queue and generation length.
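
A quick sanity check of the selection rule in initialize_ccl, with invented bucket values: the first bucket whose CCL strictly exceeds the largest current position id wins, and the last bucket index is returned as the cap (max_ccl_id).

comp_ctx_lengths_decode = [512, 1024, 2048]

def pick_ccl_id(max_position_id):
    # Same loop as initialize_ccl, extracted for illustration.
    ccl_id = 0
    for i, ccl in enumerate(comp_ctx_lengths_decode):
        if max_position_id < ccl:
            ccl_id = i
            break
    return ccl_id

assert pick_ccl_id(300) == 0   # fits in the 512 bucket
assert pick_ccl_id(700) == 1   # needs the 1024 bucket
# If max_position_id exceeded every bucket, ccl_id would fall through to 0;
# presumably the last bucket always covers ctx_len, so this case should not arise.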
@@ -853,6 +884,10 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
         # Prepare decode inputs.
         decode_inputs = self.prepare_decode_inputs()

+        if self.comp_ctx_lengths_decode is not None:
+            ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+            decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
         while prompt_queue or current_decode_ongoing.any():
             outputs = self._session.run(decode_inputs)

@@ -890,6 +925,20 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                             batch_id_map[decode_batch_id]
                         ]

+                        if self.comp_ctx_lengths_decode is not None:
+                            # Recalculate ccl_id from the maximum position id across all batch elements
+                            max_position_id = np.max(decode_inputs["position_ids"])
+
+                            # Update ccl_id and comp_ctx_lengths based on the maximum position id
+                            ccl_id_initial = 0
+                            ccl_id = ccl_id_initial
+                            for i in range(ccl_id_initial, len(self.comp_ctx_lengths_decode)):
+                                if max_position_id < self.comp_ctx_lengths_decode[i]:
+                                    ccl_id = i
+                                    break
+                            decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
                     else:
                         current_decode_ongoing[decode_batch_id] = False
                 else:
@@ -902,6 +951,15 @@ def run_continuous_batching_decode(self, prompt_queue, generation_len):
                 if self.include_sampler:
                     decode_inputs["last_accepted_output_tokens"] = decode_inputs["input_ids"]

+                if self.comp_ctx_lengths_decode is not None:
+                    # Bump ccl_id when this sequence's position id reaches the current CCL boundary
+                    if (
+                        decode_inputs["position_ids"][decode_batch_id, -1]
+                        >= self.comp_ctx_lengths_decode[ccl_id] - 1
+                    ):
+                        ccl_id = min(ccl_id + 1, max_ccl_id)
+                        decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
                 generated_id_current_index[decode_batch_id] += 1

         return decode_pause_time
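
For the boundary check above, a small walkthrough with invented numbers: a sequence moves to the next bucket as soon as its position id reaches CCL - 1, and stays at the largest bucket thereafter.

# Illustrative numbers only.
comp_ctx_lengths_decode = [512, 1024]
ccl_id, max_ccl_id = 0, 1

history = []
for position_id in (510, 511, 512):
    if position_id >= comp_ctx_lengths_decode[ccl_id] - 1:
        ccl_id = min(ccl_id + 1, max_ccl_id)
    history.append(comp_ctx_lengths_decode[ccl_id])

# 510 still uses 512; at 511 (= 512 - 1) the bump fires; afterwards capped at 1024.
assert history == [512, 1024, 1024]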
@@ -928,7 +986,18 @@ def run_decode(
         self._session.set_buffers({"logits": logits_out_placeholder})
         finished_sequences = decode_inputs["input_ids"] == self.tokenizer.eos_token_id
         num_token = 0
+
+        if self.comp_ctx_lengths_decode is not None:
+            ccl_id, max_ccl_id = self.initialize_ccl(decode_inputs)
+            decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
+        cache_index = np.max(decode_inputs["position_ids"])
         for num_token in range(1, generation_len):
+            if self.comp_ctx_lengths_decode is not None:
+                if cache_index >= self.comp_ctx_lengths_decode[ccl_id] - 1:
+                    ccl_id = min(ccl_id + 1, max_ccl_id)
+                    decode_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_decode[ccl_id]
+
             if streamer:
                 streamer.put(decode_inputs["input_ids"][0])
             outputs = self._session.run(decode_inputs)

@@ -940,6 +1009,7 @@ def run_decode(
             # Prepare inputs for next iteration
             decode_inputs["input_ids"] = self._fetch_next_token_id(outputs)
             decode_inputs["position_ids"][:, -1] += 1
+            cache_index += 1
             self.generated_ids[:, num_token] = decode_inputs["input_ids"][:, -1]
             finished_sequences |= decode_inputs["input_ids"] == self.tokenizer.eos_token_id
             if self.include_sampler:
@@ -989,6 +1059,8 @@ def __init__(
         qpc_path: str,
         full_batch_size: Optional[int] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths_prefill: Optional[List[int]] = None,
+        comp_ctx_lengths_decode: Optional[List[int]] = None,
         device_id: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,

@@ -1002,6 +1074,8 @@ def __init__(
             qpc_path=qpc_path,
             full_batch_size=full_batch_size,
             ctx_len=ctx_len,
+            comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+            comp_ctx_lengths_decode=comp_ctx_lengths_decode,
             device_id=device_id,
             enable_debug_logs=enable_debug_logs,
             write_io_dir=write_io_dir,

@@ -1013,6 +1087,8 @@ def __init__(
         self._full_batch_size = self._qaic_model.full_batch_size
         self._tokenizer = self._qaic_model.tokenizer
         self._ctx_len = ctx_len
+        self.comp_ctx_lengths_prefill = comp_ctx_lengths_prefill
+        self.comp_ctx_lengths_decode = comp_ctx_lengths_decode
         self._perf_metrics = None
         self._prompt_queue = None
         self._text_streamer = None

QEfficient/generation/vlm_generation.py

Lines changed: 16 additions & 0 deletions
@@ -83,6 +83,8 @@ def __init__(
         vision_qpc_path: str,
         device_id: Optional[List[int]] = None,
         ctx_len: Optional[int] = None,
+        comp_ctx_lengths_prefill: Optional[List[int]] = None,
+        comp_ctx_lengths_decode: Optional[List[int]] = None,
         enable_debug_logs: bool = False,
         write_io_dir: Optional[str] = None,
         full_batch_size: Optional[int] = None,

@@ -123,6 +125,8 @@ def __init__(
             qpc_path=lang_qpc_path,
             full_batch_size=full_batch_size,
             ctx_len=ctx_len,
+            comp_ctx_lengths_prefill=comp_ctx_lengths_prefill,
+            comp_ctx_lengths_decode=comp_ctx_lengths_decode,
             device_id=device_id,
             enable_debug_logs=enable_debug_logs,
             write_io_dir=write_io_dir,

@@ -294,6 +298,11 @@ def _execute_chunked_prefill(
         outputs = None
         chunk_image_idx = None

+        if self.comp_ctx_lengths_prefill is not None:
+            self.list_of_comp_ctx_lengths_prefill = [np.zeros(length) for length in self.comp_ctx_lengths_prefill]
+            prefill_ccl_id = 0
+            lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
         for i in range(num_chunks):
             input_ids_slice = lang_inputs["input_ids"][:, i * self._prefill_seq_len : (i + 1) * self._prefill_seq_len]
             position_ids_slice = lang_inputs["position_ids"][

@@ -312,6 +321,13 @@ def _execute_chunked_prefill(
             if "cross_attention_mask" in lang_inputs:
                 chunk_inputs["cross_attention_mask"] = lang_inputs["cross_attention_mask"]

+            if self.comp_ctx_lengths_prefill is not None:
+                if (i + 1) * self._prefill_seq_len > self.comp_ctx_lengths_prefill[prefill_ccl_id]:
+                    prefill_ccl_id = min(prefill_ccl_id + 1, len(self.comp_ctx_lengths_prefill) - 1)
+                    lang_inputs["comp_ctx_lengths"] = self.list_of_comp_ctx_lengths_prefill[prefill_ccl_id]
+
+                chunk_inputs["comp_ctx_lengths"] = lang_inputs["comp_ctx_lengths"]
+
             outputs = self._session.run(chunk_inputs)

             if "image_idx_output" in outputs:

QEfficient/peft/lora/layers.py

Lines changed: 3 additions & 3 deletions
@@ -42,15 +42,15 @@ def forward(self, x: torch.Tensor, lora_ids: torch.Tensor):
         # multilora implementation: lora_ids <batch_size, 1>
         other_indices_a = torch.arange(self.lora_a_weights.shape[2]).view(1, 1, -1)
         selected_lora_a_weights = CtxGatherFuncCB.apply(
-            self.lora_a_weights, lora_ids, other_indices_a
+            self.lora_a_weights, lora_ids, other_indices_a, self.lora_a_weights.shape[2]
         )  # <num_loras, 1, feature, r>
         other_indices_b = torch.arange(self.lora_b_weights.shape[2]).view(1, 1, -1)
         selected_lora_b_weights = CtxGatherFuncCB.apply(
-            self.lora_b_weights, lora_ids, other_indices_b
+            self.lora_b_weights, lora_ids, other_indices_b, self.lora_b_weights.shape[2]
         )  # <num_loras, 1, r, feature>
         other_indices_s = torch.arange(self.lora_scalings.shape[2]).view(1, 1, -1)
         selected_lora_scalings = CtxGatherFuncCB.apply(
-            self.lora_scalings, lora_ids, other_indices_s
+            self.lora_scalings, lora_ids, other_indices_s, self.lora_scalings.shape[2]
         )  # <num_loras, 1, 1, 1>

         selected_lora_a_weights = selected_lora_a_weights.squeeze(1)
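
Because CtxGatherFuncCB now takes a comp_ctx_len argument, these call sites pass the full length of the gathered axis, so the LoRA weight selection is unchanged. A minimal check of that equivalence (shapes invented):

import torch

# data: <num_loras, 1, feature, r> -> gather over dim 2 with all indices
data = torch.randn(4, 1, 16, 8)
indices = torch.arange(data.shape[2]).view(1, 1, -1)
lora_ids = torch.tensor([[0], [2]])

# Passing comp_ctx_len == data.shape[2] selects every position, i.e. the same
# result as the pre-change two-argument gather (mirrors the eager forward path).
batch_indices = lora_ids.view(-1, 1, 1)
head_indices = torch.arange(data.shape[1]).view(1, -1, 1)
out = data[batch_indices, head_indices, indices]
assert out.shape == (2, 1, 16, 8)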
