Skip to content

Commit ea320ed

Browse files
committed
Fixed KV cache shape: moved the `get_padding_shape_from_config` call to after the hash-parameter cleanup so the cache shape is computed at the correct point in `export`.
Signed-off-by: Onkar Chougule <ochougul@qti.qualcomm.com>
1 parent 128b2c9 commit ea320ed

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2343,9 +2343,6 @@ def export(
23432343
bs: int = constants.ONNX_EXPORT_EXAMPLE_BATCH_SIZE
23442344
seq_len: int = constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
23452345
fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
2346-
kv_cache_shape = get_padding_shape_from_config(
2347-
self.model.config, fbs if self.continuous_batching else bs, seq_len
2348-
)
23492346
if prefill_only:
23502347
assert not self.continuous_batching, "prefill_only=True is not supported with continuous_batching=True"
23512348

@@ -2372,6 +2369,9 @@ def export(
23722369
self.hash_params.pop("prefill_only", None)
23732370
self.hash_params.pop("num_blocks", None)
23742371

2372+
kv_cache_shape = get_padding_shape_from_config(
2373+
self.model.config, fbs if self.continuous_batching else bs, seq_len
2374+
)
23752375
example_inputs = {
23762376
"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
23772377
"position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),

0 commit comments

Comments
 (0)