
Commit fba1ac0

added ffn blocking and num blocks env variables
Signed-off-by: Onkar Chougule <ochougul@qti.qualcomm.com>
1 parent dbe2495 commit fba1ac0

File tree

- QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py
- QEfficient/transformers/models/modeling_auto.py

2 files changed: +37 −20 lines changed

QEfficient/transformers/models/gpt_oss/modeling_gpt_oss.py

Lines changed: 3 additions & 1 deletion

```diff
@@ -47,6 +47,8 @@ def __qeff_init__(self):
 
 class QEffPrefillOnlyGptOssMLP(GptOssMLP):
     def forward(self, hidden: torch.Tensor):
+        if os.environ.get("NUM_FFN_BLOCKS", None) is not None:
+            return self.blocked_ffn_forward(hidden)
         B, S, H = hidden.shape
         T = B * S
         hidden = hidden.view(T, H)
@@ -118,7 +120,7 @@ def blocked_ffn_forward(self, hidden: torch.Tensor):
 
         # ────────────────── allocate the output tensor ─────
         expert_out = hidden.new_zeros((T, H))  # accumulation buffer
-        target_blocks = int(os.environ.get("NUM_BLOCKS", 1))
+        target_blocks = int(os.environ.get("NUM_FFN_BLOCKS", 1))
         block_positions = []
         for j in range(target_blocks):
            block_positions.append(j * (T // target_blocks))
```
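For context, the gated path splits the flattened token dimension `T = B * S` into `NUM_FFN_BLOCKS` chunks and runs the FFN block by block, accumulating into a preallocated buffer. A minimal standalone sketch of that idea, not the library code — the `ffn` callable is a hypothetical stand-in for the expert MLP, and even division of `T` by the block count is assumed:

```python
import os
import torch

def blocked_ffn_sketch(hidden: torch.Tensor, ffn) -> torch.Tensor:
    """Illustrative only: run `ffn` over NUM_FFN_BLOCKS chunks of the tokens."""
    # hidden: (B, S, H) -> flatten to (T, H), T = B * S
    B, S, H = hidden.shape
    T = B * S
    hidden = hidden.view(T, H)

    target_blocks = int(os.environ.get("NUM_FFN_BLOCKS", 1))
    expert_out = hidden.new_zeros((T, H))  # accumulation buffer

    # Block start offsets, mirroring the commit's block_positions loop.
    block_positions = [j * (T // target_blocks) for j in range(target_blocks)]
    block_len = T // target_blocks  # assumes T divides evenly (hypothetical)

    for start in block_positions:
        # ffn maps (block_len, H) -> (block_len, H); a stand-in for the expert MLP
        expert_out[start : start + block_len] = ffn(hidden[start : start + block_len])

    return expert_out.view(B, S, H)
```

Processing `T // NUM_FFN_BLOCKS` tokens at a time keeps the peak activation footprint of the MLP proportional to the block size rather than to the full prefill length.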

QEfficient/transformers/models/modeling_auto.py

Lines changed: 34 additions & 19 deletions
```diff
@@ -2314,6 +2314,35 @@ def get_model_config(self) -> dict:
         """
         return self.model.config.__dict__
 
+    def get_seq_len_and_handle_specialized_prefill_model(self, prefill_seq_len: Optional[int] = None) -> int:
+        num_q_blocks = os.environ.get("NUM_Q_BLOCKS", None)
+        if num_q_blocks is None:
+            block_size = 128
+            if prefill_seq_len is None or prefill_seq_len % block_size != 0 or prefill_seq_len < 128:
+                raise ValueError(
+                    f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}, "
+                    f"or the `NUM_Q_BLOCKS` ENV variable must be set. "
+                    f"Received: prefill_seq_len={prefill_seq_len}"
+                )
+
+            num_q_blocks = prefill_seq_len // block_size
+            logger.warning(
+                f"Setting NUM_Q_BLOCKS={num_q_blocks} used in attention Q-blocking for prefill_only model, please set ENV variable `NUM_Q_BLOCKS` to override"
+            )
+            os.environ["NUM_Q_BLOCKS"] = str(num_q_blocks)
+
+        num_ffn_blocks = os.environ.get("NUM_FFN_BLOCKS", None)
+        min_seq_len = max(int(num_q_blocks), int(num_ffn_blocks)) if num_ffn_blocks else int(num_q_blocks)
+
+        self.prefill(True)
+        self.hash_params["prefill_only"] = True
+        self.hash_params["num_blocks"] = os.environ["NUM_Q_BLOCKS"]
+        return (
+            min_seq_len
+            if min_seq_len > constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
+            else constants.ONNX_EXPORT_EXAMPLE_SEQ_LEN
+        )
+
     def export(
         self,
         export_dir: Optional[str] = None,
@@ -2345,25 +2374,11 @@ def export(
         fbs: int = constants.ONNX_EXPORT_EXAMPLE_FBS
         if prefill_only:
             assert not self.continuous_batching, "prefill_only=True is not supported with continuous_batching=True"
-
-            if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH:
-                block_size = os.environ.get("BLOCK_SIZE", None)
-                if block_size is None:
-                    block_size = 128
-                    logger.warning(
-                        "Setting BLOCK_SIZE=128 for prefill_only model, please set ENV variable `BLOCK_SIZE` to override"
-                    )
-                if prefill_seq_len is None or prefill_seq_len % block_size != 0:
-                    raise ValueError(
-                        f"When prefill_only=True, 'prefill_seq_len' must be explicitly set and divisible by block_size={block_size}. "
-                        f"Received: prefill_seq_len={prefill_seq_len}"
-                    )
-                os.environ["NUM_BLOCKS"] = str(prefill_seq_len // block_size)
-
-                self.prefill(True)
-                self.hash_params["prefill_only"] = True
-                self.hash_params["num_blocks"] = os.environ["NUM_BLOCKS"]
-                seq_len = prefill_seq_len // block_size if (prefill_seq_len // block_size) > seq_len else seq_len
+            seq_len = (
+                self.get_seq_len_and_handle_specialized_prefill_model(prefill_seq_len)
+                if self.model.config.model_type in SPECIALIZED_PREFILL_ONLY_MODEL_ARCH
+                else seq_len
+            )
         else:
             self.prefill(False)
             self.hash_params.pop("prefill_only", None)
```
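Putting the two files together: `NUM_FFN_BLOCKS` opts the GPT-OSS MLP into the blocked FFN path, while `NUM_Q_BLOCKS` (derived as `prefill_seq_len // 128` when unset) drives attention Q-blocking and the exported sequence length. A hedged usage sketch — the `QEFFAutoModelForCausalLM` entry point, the model card, and the exact `export(...)` keyword arguments are assumptions based on the surrounding code, not confirmed by this diff:

```python
import os

# Hypothetical end-to-end usage; class name and export kwargs are assumed.
from QEfficient import QEFFAutoModelForCausalLM

os.environ["NUM_Q_BLOCKS"] = "32"    # attention Q-blocking block count
os.environ["NUM_FFN_BLOCKS"] = "16"  # enables the blocked FFN forward path

model = QEFFAutoModelForCausalLM.from_pretrained("openai/gpt-oss-20b")
model.export(prefill_only=True, prefill_seq_len=4096)

# With NUM_Q_BLOCKS unset, the new helper would derive it as
# prefill_seq_len // 128 = 4096 // 128 = 32 and emit a warning that
# the env variable can override this default.
```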
