From 7e0bcda1714ee513fde0a6ba9b8ac0e02df87d5e Mon Sep 17 00:00:00 2001 From: os-gabe Date: Tue, 28 Apr 2026 17:30:51 -0700 Subject: [PATCH] Make FlashInfer cache dtype hardware-aware (fix bf16 on SM<80) FlashInferKVCacheManager unconditionally mapped fp32 -> bf16, which fails at runtime on SM<80 (Turing/Volta, e.g. Titan RTX) where bf16 kernels aren't available. AggregatorStream also passed tokens.dtype, which autocast-exempt ops (LayerNorm) leak as fp32, so the bug fires even when demo.py selects fp16. - flashinfer_cache: hardware-aware fp32/None fallback (bf16 only on SM>=80). - aggregator/stream: prefer aggregator parameter dtype before falling through to the cache's default. SM>=80 behavior is unchanged. Co-Authored-By: Claude Opus 4.7 (1M context) --- lingbot_map/aggregator/stream.py | 14 +++++++++++++- lingbot_map/layers/flashinfer_cache.py | 22 ++++++++++++++++++++-- 2 files changed, 33 insertions(+), 3 deletions(-) diff --git a/lingbot_map/aggregator/stream.py b/lingbot_map/aggregator/stream.py index e442160..e31674b 100644 --- a/lingbot_map/aggregator/stream.py +++ b/lingbot_map/aggregator/stream.py @@ -200,7 +200,10 @@ def _get_flashinfer_manager(self, device, dtype, tokens_per_frame=None): Args: device: Device for cache tensors. - dtype: Data type for cache tensors. + dtype: Data type for cache tensors. If fp32 (e.g. tokens kept fp32 by + an autocast-exempt op like LayerNorm), prefer the aggregator's + parameter dtype, which reflects the inference dtype chosen at + model load time. tokens_per_frame: Actual number of tokens per frame (patches + specials). If None, falls back to assuming square images of self.img_size. 
""" @@ -212,6 +215,15 @@ def _get_flashinfer_manager(self, device, dtype, tokens_per_frame=None): tokens_per_frame = (self.img_size // self.patch_size) ** 2 + self.num_special_tokens # max_num_frames: scale + window + headroom max_num_frames = self.kv_cache_scale_frames + self.kv_cache_sliding_window + 16 + + if dtype is None or dtype == torch.float32: + try: + param_dtype = next(self.parameters()).dtype + if param_dtype != torch.float32: + dtype = param_dtype + except StopIteration: + pass + self.kv_cache_manager = FlashInferKVCacheManager( num_blocks=self.depth, max_num_frames=max_num_frames, diff --git a/lingbot_map/layers/flashinfer_cache.py b/lingbot_map/layers/flashinfer_cache.py index f4c6a48..07b13c1 100644 --- a/lingbot_map/layers/flashinfer_cache.py +++ b/lingbot_map/layers/flashinfer_cache.py @@ -52,6 +52,20 @@ FLASHINFER_AVAILABLE = False +def _default_flashinfer_dtype(device: torch.device) -> torch.dtype: + """Pick a FlashInfer-compatible dtype based on CUDA compute capability. + + bfloat16 requires SM80+ (Ampere). On SM<80 (e.g. Turing SM75 / Titan RTX) + only fp16 FlashInfer kernels are available. + """ + if torch.cuda.is_available(): + idx = device.index if (isinstance(device, torch.device) and device.index is not None) else None + major, _ = torch.cuda.get_device_capability(idx) + if major >= 8: + return torch.bfloat16 + return torch.float16 + + class FlashInferKVCacheManager: """ Two-stream paged KV cache: patch pages (recyclable) + special pages (append-only). @@ -122,8 +136,12 @@ def __init__( if force_fp32: self.dtype = torch.float32 else: - if dtype == torch.float32: - dtype = torch.bfloat16 + # FlashInfer FA2 only supports fp16/bf16. If the caller passed fp32 + # (or None) we have to pick one. bfloat16 needs SM80+ (Ampere); on + # SM<80 (Turing/Volta, e.g. Titan RTX SM75) bf16 kernels won't run, + # so fall back to fp16 there. 
+ if dtype is None or dtype == torch.float32: + dtype = _default_flashinfer_dtype(device) self.dtype = dtype self.device = device