Skip to content

Commit 37f3681

Browse files
committed
include num_ffn_blocks in hash
Signed-off-by: Onkar Chougule <ochougul@qti.qualcomm.com>
1 parent fba1ac0 commit 37f3681

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

QEfficient/transformers/models/modeling_auto.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2329,14 +2329,22 @@ def get_seq_len_and_handle_specialized_prefill_model(self, prefill_seq_len: Opti
23292329
logger.warning(
23302330
f"Setting NUM_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_BLOCKS` to override"
23312331
)
2332-
os.environ["NUM_Q_BLOCKS"] = num_q_blocks
2332+
os.environ["NUM_Q_BLOCKS"] = str(num_q_blocks)
2333+
num_q_blocks = int(num_q_blocks)
23332334

23342335
num_ffn_blocks = os.environ.get("NUM_FFN_BLOCKS", None)
2335-
min_seq_len = int(max(num_q_blocks, num_ffn_blocks)) if num_ffn_blocks else num_q_blocks
2336+
num_ffn_blocks = int(num_ffn_blocks) if num_ffn_blocks else num_ffn_blocks
2337+
min_seq_len = max(num_q_blocks, num_ffn_blocks) if num_ffn_blocks else num_q_blocks
2338+
if (num_ffn_blocks and min_seq_len % num_ffn_blocks != 0) or min_seq_len % num_q_blocks != 0:
2339+
raise ValueError(
2340+
f"Got NUM_FFN_BLOCKS={num_ffn_blocks} and NUM_Q_BLOCKS={num_q_blocks}, tried to set seq_len={min_seq_len} for export but,"
2341+
"seq_len is not divisible by either num_ffn_blocks or num_q_blocks, try chaning the values."
2342+
)
23362343

23372344
self.prefill(True)
23382345
self.hash_params["prefill_only"] = True
2339-
self.hash_params["num_blocks"] = os.environ["NUM_BLOCKS"]
2346+
self.hash_params["num_blocks"] = num_q_blocks
2347+
self.hash_params["num_ffn_blocks"] = num_ffn_blocks
23402348
return (
23412349
min_seq_len
23422350
if min_seq_len > constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN

0 commit comments

Comments
 (0)