Skip to content

Commit 3906b58

Browse files
Qualcomm AI Engine Direct - calibration thread auto-tuning (#18184)
1 parent 4a34bc4 commit 3906b58

2 files changed

Lines changed: 93 additions & 7 deletions

File tree

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -528,10 +528,19 @@ def _build_parser():
528528

529529
parser.add_argument("-v", "--verbose", action="store_true")
530530

531+
parser.add_argument(
532+
"--calibration_num_threads",
533+
type=int,
534+
default=0,
535+
help="Thread count for calibration forward passes. 0 = auto-tune (default).",
536+
)
537+
531538
return parser
532539

533540

534541
def export_llama(args) -> None:
542+
if args.calibration_num_threads < 0:
543+
raise ValueError("--calibration_num_threads must be >= 0")
535544
if args.compile_only and args.pre_gen_pte:
536545
raise RuntimeError("Cannot set both compile_only and pre_gen_pte as true")
537546
if (TASKS_EVAL or SQNR_EVAL) in args.eval_methods and args.model_mode not in {

examples/qualcomm/oss_scripts/llama/wrappers/llm_wrappers.py

Lines changed: 84 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import inspect
88
import json
99
import logging
10+
import os
11+
import time
1012
import types
1113

1214
from functools import partial
@@ -413,6 +415,63 @@ def _tag_ios(self, node, fixed_point_type):
413415

414416
return quant_io_type
415417

418+
def _auto_tune_calibration_threads(self):
    """Find the optimal thread count for calibration via quick microbenchmark.

    AR1 decode calibration is SGEMV-dominated (memory-bandwidth-bound).
    The default thread count (os.cpu_count()) is typically far too high,
    causing massive OpenMP sync overhead. This runs a few forward passes
    of ``self.decoder(*self.export_input)`` at candidate thread counts and
    picks the fastest.

    The global torch thread count is always restored to its value at
    entry before this method returns, even if a forward pass raises.

    Returns:
        The candidate thread count with the lowest total time. If every
        candidate fails, the baseline (the current torch thread count
        clamped to the CPUs available to this process) is returned.
    """
    # Single source of truth for the number of timed forward passes; it is
    # used both by the benchmark loop and by the per-forward time in the log.
    timed_runs = 3

    # Use sched_getaffinity when available — it respects cgroup/taskset
    # constraints (e.g. containers), unlike os.cpu_count() which returns
    # the host total regardless of pinning.
    available = (
        len(os.sched_getaffinity(0))
        if hasattr(os, "sched_getaffinity")
        else (os.cpu_count() or 1)
    )
    original = torch.get_num_threads()
    baseline = min(original, available)

    # Sample fractions of the thread ceiling from low through the
    # bandwidth-saturation knee up to the current default.
    fractions = (1 / 8, 1 / 4, 3 / 8, 1 / 2, 2 / 3, 3 / 4, 1.0)
    candidates = sorted(
        {1, baseline} | {max(1, round(baseline * f)) for f in fractions}
    )

    best_threads, best_time = baseline, float("inf")
    try:
        for n_threads in candidates:
            torch.set_num_threads(n_threads)
            try:
                with torch.no_grad():
                    self.decoder(*self.export_input)  # warmup
                    t0 = time.perf_counter()
                    for _ in range(timed_runs):
                        self.decoder(*self.export_input)
                    elapsed = time.perf_counter() - t0
            except Exception:
                # Best-effort benchmark: a failing candidate must not abort
                # quantization, but keep the traceback so the cause can be
                # diagnosed when debug logging is enabled.
                logging.debug(
                    "Auto-tune: threads=%d failed, skipping",
                    n_threads,
                    exc_info=True,
                )
                continue
            if elapsed < best_time:
                best_threads, best_time = n_threads, elapsed
    finally:
        torch.set_num_threads(original)

    if best_time == float("inf"):
        logging.warning(
            "Auto-tune: all candidates %s failed, falling back to %d threads",
            candidates,
            baseline,
        )
        return baseline

    logging.info(
        "Auto-tune calibration threads: tested %s, best=%d (%.1fms/fwd)",
        candidates,
        best_threads,
        best_time / timed_runs * 1000,
    )
    return best_threads
474+
416475
def _calibrate(
417476
self,
418477
model,
@@ -559,6 +618,14 @@ def quantize(self, request: Request): # noqa: C901
559618
self.decoder, self.export_input, strict=True
560619
).module()
561620

621+
# Auto-tune thread count BEFORE prepare_pt2e so the benchmark
622+
# runs on the exported model without observers — no risk of
623+
# polluting observer state with synthetic inputs.
624+
if self.mode == Mode.DECODE or not self.model_args.use_kv_cache:
625+
calib_threads = getattr(self.control_args, "calibration_num_threads", 0)
626+
if calib_threads <= 0:
627+
calib_threads = self._auto_tune_calibration_threads()
628+
562629
self.decoder = prepare_pt2e(self.decoder, quantizer)
563630
if self.apply_embedding:
564631
self.tok_embedding = prepare_pt2e(
@@ -567,14 +634,24 @@ def quantize(self, request: Request): # noqa: C901
567634

568635
# start calibration (only for kv mode or prefill mode without kv cache)
569636
if self.mode == Mode.DECODE or not self.model_args.use_kv_cache:
570-
self._calibrate(
571-
model=self.decoder,
572-
tokenizer=data.tokenizer,
573-
event="prepare_pt2e",
574-
user_calibration_data=data.calibration_data.datasets,
575-
tok_embedding=self.tok_embedding,
576-
intermediate_outputs=intermediate_outputs,
637+
original_threads = torch.get_num_threads()
638+
torch.set_num_threads(calib_threads)
639+
logging.info(
640+
"Calibration using %d threads (was %d)",
641+
calib_threads,
642+
original_threads,
577643
)
644+
try:
645+
self._calibrate(
646+
model=self.decoder,
647+
tokenizer=data.tokenizer,
648+
event="prepare_pt2e",
649+
user_calibration_data=data.calibration_data.datasets,
650+
tok_embedding=self.tok_embedding,
651+
intermediate_outputs=intermediate_outputs,
652+
)
653+
finally:
654+
torch.set_num_threads(original_threads)
578655
else:
579656
# one dummy inference to remove affine observer
580657
# error happened in convert_pt2e

0 commit comments

Comments
 (0)