Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion lingbot_map/aggregator/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,10 @@ def _get_flashinfer_manager(self, device, dtype, tokens_per_frame=None):

Args:
device: Device for cache tensors.
dtype: Data type for cache tensors.
dtype: Data type for cache tensors. If fp32 (e.g. tokens kept fp32 by
an autocast-exempt op like LayerNorm), prefer the aggregator's
parameter dtype, which reflects the inference dtype chosen at
model load time.
tokens_per_frame: Actual number of tokens per frame (patches + specials).
If None, falls back to assuming square images of self.img_size.
"""
Expand All @@ -212,6 +215,15 @@ def _get_flashinfer_manager(self, device, dtype, tokens_per_frame=None):
tokens_per_frame = (self.img_size // self.patch_size) ** 2 + self.num_special_tokens
# max_num_frames: scale + window + headroom
max_num_frames = self.kv_cache_scale_frames + self.kv_cache_sliding_window + 16

if dtype is None or dtype == torch.float32:
try:
param_dtype = next(self.parameters()).dtype
if param_dtype != torch.float32:
dtype = param_dtype
except StopIteration:
pass

self.kv_cache_manager = FlashInferKVCacheManager(
num_blocks=self.depth,
max_num_frames=max_num_frames,
Expand Down
22 changes: 20 additions & 2 deletions lingbot_map/layers/flashinfer_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@
FLASHINFER_AVAILABLE = False


def _default_flashinfer_dtype(device: torch.device) -> torch.dtype:
"""Pick a FlashInfer-compatible dtype based on CUDA compute capability.

bfloat16 requires SM80+ (Ampere). On SM<80 (e.g. Turing SM75 / Titan RTX)
only fp16 FlashInfer kernels are available.
"""
if torch.cuda.is_available():
idx = device.index if (isinstance(device, torch.device) and device.index is not None) else None
major, _ = torch.cuda.get_device_capability(idx)
if major >= 8:
return torch.bfloat16
return torch.float16


class FlashInferKVCacheManager:
"""
Two-stream paged KV cache: patch pages (recyclable) + special pages (append-only).
Expand Down Expand Up @@ -122,8 +136,12 @@ def __init__(
if force_fp32:
self.dtype = torch.float32
else:
if dtype == torch.float32:
dtype = torch.bfloat16
# FlashInfer FA2 only supports fp16/bf16. If the caller passed fp32
# (or None) we have to pick one. bfloat16 needs SM80+ (Ampere); on
# SM<80 (Turing/Volta, e.g. Titan RTX SM75) bf16 kernels won't run,
# so fall back to fp16 there.
if dtype is None or dtype == torch.float32:
dtype = _default_flashinfer_dtype(device)
self.dtype = dtype
self.device = device

Expand Down