77import inspect
88import json
99import logging
10+ import os
11+ import time
1012import types
1113
1214from functools import partial
@@ -413,6 +415,63 @@ def _tag_ios(self, node, fixed_point_type):
413415
414416 return quant_io_type
415417
418+ def _auto_tune_calibration_threads (self ):
419+ """Find the optimal thread count for calibration via quick microbenchmark.
420+
421+ AR1 decode calibration is SGEMV-dominated (memory-bandwidth-bound).
422+ The default thread count (os.cpu_count()) is typically far too high,
423+ causing massive OpenMP sync overhead. This runs a few forward passes
424+ at candidate thread counts and picks the fastest.
425+ """
426+ # Use sched_getaffinity when available — it respects cgroup/taskset
427+ # constraints (e.g. containers), unlike os.cpu_count() which returns
428+ # the host total regardless of pinning.
429+ available = (
430+ len (os .sched_getaffinity (0 ))
431+ if hasattr (os , "sched_getaffinity" )
432+ else (os .cpu_count () or 1 )
433+ )
434+ baseline = min (torch .get_num_threads (), available )
435+ # Sample fractions of the thread ceiling from low through the
436+ # bandwidth-saturation knee up to the current default.
437+ fractions = (1 / 8 , 1 / 4 , 3 / 8 , 1 / 2 , 2 / 3 , 3 / 4 , 1.0 )
438+ candidates = sorted (
439+ {1 , baseline } | {max (1 , round (baseline * f )) for f in fractions }
440+ )
441+ original = torch .get_num_threads ()
442+ best_threads , best_time = original , float ("inf" )
443+ try :
444+ for n_threads in candidates :
445+ torch .set_num_threads (n_threads )
446+ try :
447+ with torch .no_grad ():
448+ self .decoder (* self .export_input ) # warmup
449+ t0 = time .perf_counter ()
450+ for _ in range (3 ):
451+ self .decoder (* self .export_input )
452+ elapsed = time .perf_counter () - t0
453+ if elapsed < best_time :
454+ best_threads , best_time = n_threads , elapsed
455+ except Exception :
456+ logging .debug ("Auto-tune: threads=%d failed, skipping" , n_threads )
457+ continue
458+ finally :
459+ torch .set_num_threads (original )
460+ if best_time == float ("inf" ):
461+ logging .warning (
462+ "Auto-tune: all candidates %s failed, falling back to %d threads" ,
463+ candidates ,
464+ baseline ,
465+ )
466+ return baseline
467+ logging .info (
468+ "Auto-tune calibration threads: tested %s, best=%d (%.1fms/fwd)" ,
469+ candidates ,
470+ best_threads ,
471+ best_time / 3 * 1000 ,
472+ )
473+ return best_threads
474+
416475 def _calibrate (
417476 self ,
418477 model ,
@@ -559,6 +618,14 @@ def quantize(self, request: Request): # noqa: C901
559618 self .decoder , self .export_input , strict = True
560619 ).module ()
561620
621+ # Auto-tune thread count BEFORE prepare_pt2e so the benchmark
622+ # runs on the exported model without observers — no risk of
623+ # polluting observer state with synthetic inputs.
624+ if self .mode == Mode .DECODE or not self .model_args .use_kv_cache :
625+ calib_threads = getattr (self .control_args , "calibration_num_threads" , 0 )
626+ if calib_threads <= 0 :
627+ calib_threads = self ._auto_tune_calibration_threads ()
628+
562629 self .decoder = prepare_pt2e (self .decoder , quantizer )
563630 if self .apply_embedding :
564631 self .tok_embedding = prepare_pt2e (
@@ -567,14 +634,24 @@ def quantize(self, request: Request): # noqa: C901
567634
568635 # start calibration (only for kv mode or prefill mode without kv cache)
569636 if self .mode == Mode .DECODE or not self .model_args .use_kv_cache :
570- self ._calibrate (
571- model = self .decoder ,
572- tokenizer = data .tokenizer ,
573- event = "prepare_pt2e" ,
574- user_calibration_data = data .calibration_data .datasets ,
575- tok_embedding = self .tok_embedding ,
576- intermediate_outputs = intermediate_outputs ,
637+ original_threads = torch .get_num_threads ()
638+ torch .set_num_threads (calib_threads )
639+ logging .info (
640+ "Calibration using %d threads (was %d)" ,
641+ calib_threads ,
642+ original_threads ,
577643 )
644+ try :
645+ self ._calibrate (
646+ model = self .decoder ,
647+ tokenizer = data .tokenizer ,
648+ event = "prepare_pt2e" ,
649+ user_calibration_data = data .calibration_data .datasets ,
650+ tok_embedding = self .tok_embedding ,
651+ intermediate_outputs = intermediate_outputs ,
652+ )
653+ finally :
654+ torch .set_num_threads (original_threads )
578655 else :
579656 # one dummy inference to remove affine observer
580657 # error happened in convert_pt2e
0 commit comments