From 47858cc728f6758019aef34e58a2b695d08247db Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Sat, 2 May 2026 14:40:46 -0700 Subject: [PATCH 1/6] Use AITER DSv4 indexer topk --- atom/model_ops/attentions/deepseek_v4_attn.py | 4 + atom/models/deepseek_v4.py | 97 ++++++++++--------- 2 files changed, 53 insertions(+), 48 deletions(-) diff --git a/atom/model_ops/attentions/deepseek_v4_attn.py b/atom/model_ops/attentions/deepseek_v4_attn.py index b975e73b4..40641f896 100644 --- a/atom/model_ops/attentions/deepseek_v4_attn.py +++ b/atom/model_ops/attentions/deepseek_v4_attn.py @@ -583,7 +583,11 @@ def _build_v4_indexer_meta( return { "max_k": max_k, + "max_committed": int(n_committed_per_seq.max()) if bs > 0 else 0, + "num_seqs": int(bs), "gather_indices": gather_indices, + "batch_id_per_token_gpu": batch_id_per_token_gpu, + "n_committed_per_seq_gpu": n_committed_per_seq_gpu, "seq_base_per_token_gpu": seq_base_per_token_gpu, "cu_starts_gpu": seq_base_per_token_gpu, # alias for fp8_mqa_logits "cu_ends_gpu": cu_ends_gpu, diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index 6256be120..451a93e5f 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -25,10 +25,9 @@ import torch.nn.functional as F from torch import nn -from aiter import QuantType as _AiterQuantType -from aiter import dtypes, get_hip_quant +from aiter import dtypes from aiter.dist.parallel_state import get_tensor_model_parallel_world_size -from aiter.ops.triton.fp8_mqa_logits import fp8_mqa_logits +from aiter.ops.triton.attention.dsv4_indexer import dsv4_indexer_topk from atom.config import Config from atom.model_ops.embed_head import VocabParallelEmbedding from atom.model_ops.layernorm import RMSNorm, rmsnorm2d_fwd_ @@ -41,6 +40,7 @@ from atom.model_ops.moe import FusedMoE from atom.model_ops.quant_v4 import ( act_quant_inplace, + fp4_act_quant_inplace, rotate_activation, ) from atom.model_ops.sparse_attn_v4 import ( # noqa: F401 @@ -985,8 +985,6 @@ def __init__(self, args: DeepseekV4Args, compress_ratio: int = 4, prefix: str = prefix=f"{prefix}.weights_proj", ) self.softmax_scale = self.head_dim**-0.5 - # Init-time hoists out of `forward_batched`'s hot path. - self._fp8_quant_func = get_hip_quant(_AiterQuantType.per_1x128) self._weights_scale = self.softmax_scale * self.n_heads**-0.5 self.compressor = Compressor( @@ -1021,7 +1019,7 @@ def forward_batched( block_tables: torch.Tensor, # [bs, max_blocks_per_seq] indexer_meta: dict, ) -> torch.Tensor: - """Batched score+topk across all seqs in one fp8_mqa_logits call. + """Batched score+topk across all seqs with AITER's DSv4 Indexer op. Caller must invoke `self.compressor` once batched BEFORE this so all seqs' Indexer kv_cache is already populated. @@ -1045,66 +1043,69 @@ def forward_batched( if max_k == 0: return torch.full((total_tokens, 0), -1, dtype=torch.int32, device=device) - # Q proj + RoPE + rotate (batched). + # Q proj + RoPE + reference Indexer FP4 simulation (batched). q = self.wq_b(qr_full).view(total_tokens, self.n_heads, self.head_dim) q = q.unsqueeze(0) self.rotary_emb(positions, q[..., -rd:]) q = rotate_activation(q) - - # FP8 quant Q + batched gather K + FP8 quant. `_fp8_quant_func`, - # `_weights_scale` precomputed in __init__. 
- q_2d = q.squeeze(0).contiguous().view(-1, self.head_dim) - q_fp8, q_scale = self._fp8_quant_func(q_2d, quant_dtype=dtypes.fp8) - q_fp8 = q_fp8.view(total_tokens, self.n_heads, self.head_dim) - q_scale = q_scale.view(total_tokens, self.n_heads, 1) + q = q.squeeze(0).contiguous() + fp4_act_quant_inplace(q, 32) + if q.dtype not in (torch.float16, torch.bfloat16): + q = q.to(torch.bfloat16) gathered_flat = _v4_gather_compressed_batched( self.kv_cache, block_tables, indexer_meta["gather_indices"] ) - k_fp8, k_scale = self._fp8_quant_func(gathered_flat, quant_dtype=dtypes.fp8) - - # weights = weights_proj * q_scale * (softmax_scale * 1/sqrt(H)) - weights = ( - (self.weights_proj(x_full).unsqueeze(-1) * q_scale * self._weights_scale) - .squeeze(-1) - .float() + # The fused compressor currently stores dequantized BF16. Convert the + # gathered Indexer KV into the same rotated FP4-dequantized basis as Q + # before handing it to the AITER scorer. + gathered_flat = rotate_activation(gathered_flat.contiguous()) + fp4_act_quant_inplace(gathered_flat, 32) + if gathered_flat.dtype not in (torch.float16, torch.bfloat16): + gathered_flat = gathered_flat.to(torch.bfloat16) + + # AITER PR #2998 accepts batched dense KV, not ATOM's paged cache. Build + # the dense [B, max_committed, D] view from the already-gathered flat + # cache. Invalid tail rows are masked by kv_lens inside the kernel. + num_seqs = indexer_meta["num_seqs"] + max_committed = indexer_meta["max_committed"] + kv_batched = torch.zeros( + (num_seqs, max_committed, self.head_dim), + dtype=gathered_flat.dtype, + device=device, ) + gather_indices = indexer_meta["gather_indices"] + batch_ids = gather_indices["batch_ids_gpu"] + if batch_ids is not None and gathered_flat.numel() > 0: + local_idx = ( + gather_indices["block_in_seq_gpu"] * (_V4_BLOCK_SIZE // ratio) + + gather_indices["slot_in_block_gpu"] + ) + kv_batched[batch_ids.long(), local_idx.long()] = gathered_flat + + weights = (self.weights_proj(x_full) * self._weights_scale).float() # All per-token broadcast helpers + layer-invariant derivations are # pre-built in `_build_v4_indexer_meta`. - seq_base_per_token = indexer_meta["seq_base_per_token_gpu"] - cu_starts = indexer_meta["cu_starts_gpu"] - cu_ends = indexer_meta["cu_ends_gpu"] - future_threshold = indexer_meta["future_threshold_gpu"] - width_mask = indexer_meta["width_mask_gpu"] offset_per_token = indexer_meta["offset_per_token_gpu"] - is_prefill_per_token = indexer_meta["is_prefill_per_token_gpu"] - - logits = fp8_mqa_logits( - Q=q_fp8, - KV=k_fp8, - kv_scales=k_scale.view(-1).float(), + topk_local = dsv4_indexer_topk( + q=q, + kv=kv_batched, weights=weights, - cu_starts=cu_starts, - cu_ends=cu_ends, - ) # [total_tokens, total_committed] fp32; outside [start,end) is -inf - - # PyTorch topk over -inf-masked logits. aiter `top_k_per_row_prefill` - # would be the obvious replacement but it hardcodes K=2048 — V4's - # K=index_topk=64 doesn't fit (the kernel writes 2048 ints/row, - # overflowing a [tok, 64] indices buffer and corrupting memory). - topk_global = logits.topk(max_k, dim=-1)[1].to(torch.int32) - # Global flat index → seq-local compress idx; drop slots past per-seq K. 
- topk_local = topk_global - seq_base_per_token.unsqueeze(1) - topk_local = topk_local.masked_fill(width_mask, -1) + positions=positions, + index_topk=self.index_topk, + offset=0, + seq_ids=indexer_meta["batch_id_per_token_gpu"], + kv_lens=indexer_meta["n_committed_per_seq_gpu"], + ratio=ratio, + ) # Per-seq offset to land indices in the [SWA || compressed] kv_sa - # layout consumed by sparse_attn (token_num for fresh prefill, win - # for decode); future-mask only applies in fresh prefill. - future_mask = is_prefill_per_token & (topk_local >= future_threshold) + # layout consumed by sparse_attn. AITER already applies the DSv4 causal + # compressed-entry visibility rule, so only invalid rows need masking. topk_with_offset = topk_local + offset_per_token.unsqueeze(1) topk_final = torch.where( - (topk_local < 0) | future_mask, + topk_local < 0, torch.full_like(topk_local, -1), topk_with_offset, ) From 9e09677e1c10a9da81e3fc3247766ca1a232c9e7 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Sat, 2 May 2026 19:05:59 -0700 Subject: [PATCH 2/6] Trim structured answer completions --- atom/entrypoints/openai/api_server.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/atom/entrypoints/openai/api_server.py b/atom/entrypoints/openai/api_server.py index 01ef08081..005e8698e 100644 --- a/atom/entrypoints/openai/api_server.py +++ b/atom/entrypoints/openai/api_server.py @@ -16,6 +16,7 @@ import asyncio import json import logging +import re import time import uuid from asyncio import AbstractEventLoop @@ -56,6 +57,8 @@ # Constants DEFAULT_HOST = "0.0.0.0" DEFAULT_PORT = 8000 +STRUCTURED_ANSWER_PROMPT_MARKER = "formatted as: ####" +STRUCTURED_ANSWER_RE = re.compile(r"(?m)^####[^\r\n]*(?:\r?\n)?") # ============================================================================ @@ -157,6 +160,23 @@ def _coerce_n(requested_n: Optional[int], temperature: Optional[float]) -> int: return n +def _trim_structured_answer_output(prompt: str, text: str) -> str: + """Trim eval-style completions after their first final-answer line. + + GSM8K-style prompts ask the model to end with a ``####`` answer line. + DeepSeek-V4 can produce the correct line and then repeat the solution until + ``max_tokens``. Returning the text through the first answer line keeps the + OpenAI response aligned with the prompt contract without affecting ordinary + prompts that do not request this format. 
+ """ + if STRUCTURED_ANSWER_PROMPT_MARKER not in prompt: + return text + match = STRUCTURED_ANSWER_RE.search(text) + if match is None: + return text + return text[: match.end()].rstrip() + + def _send_stream_chunk_direct( request_output: RequestOutput, request_id: str, @@ -267,6 +287,7 @@ def do_preprocess(): break text = tokenizer.decode(all_token_ids, skip_special_tokens=True) + text = _trim_structured_answer_output(prompt, text) num_tokens_input = ( seq.num_prompt_tokens if seq is not None else len(tokenizer.encode(prompt)) ) @@ -387,9 +408,11 @@ def do_preprocess(): and num_tokens_output > 1 else 0.0 ) + text = tokenizer.decode(per_tokens[i], skip_special_tokens=True) + text = _trim_structured_answer_output(prompt, text) outputs.append( { - "text": tokenizer.decode(per_tokens[i], skip_special_tokens=True), + "text": text, "token_ids": per_tokens[i], "finish_reason": per_finish_reason[i], "num_tokens_input": num_tokens_input, From 4634e6ac46bb377834ac4d45d3be5e288a222c82 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Sun, 3 May 2026 21:38:25 -0700 Subject: [PATCH 3/6] fix: avoid dsv4 metadata repeat sync --- atom/model_ops/attentions/deepseek_v4_attn.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/atom/model_ops/attentions/deepseek_v4_attn.py b/atom/model_ops/attentions/deepseek_v4_attn.py index d3dc29996..9140e1175 100644 --- a/atom/model_ops/attentions/deepseek_v4_attn.py +++ b/atom/model_ops/attentions/deepseek_v4_attn.py @@ -1312,6 +1312,13 @@ def _attach_v4_per_fwd_meta( cu_seqlens_q_np[: scheduled_bs + 1], dtype=np.int64 ) token_num_per_seq = cu_seqlens_q_arr[1:] - cu_seqlens_q_arr[:scheduled_bs] + repeat_output_size = int(token_num_per_seq.sum()) + if repeat_output_size != total_tokens: + raise ValueError( + "DeepSeek-V4 metadata token count mismatch: " + f"sum(cu_seqlens_q diff)={repeat_output_size}, " + f"total_tokens={total_tokens}" + ) start_pos_per_seq_np = np.asarray( start_pos_per_seq_cpu[:scheduled_bs], dtype=np.int64 ) @@ -1323,8 +1330,12 @@ def _attach_v4_per_fwd_meta( token_num_per_seq_gpu = self._stage( "v4_meta_token_num_per_seq", token_num_per_seq ) + # `repeats` is a CUDA tensor. Supplying the CPU-known output size avoids + # PyTorch synchronizing to compute sum(repeats) on every high-conc prefill. 
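+        # A minimal sketch of the pattern (hypothetical values, not ATOM's
+        # real shapes):
+        #   repeats = torch.tensor([2, 1, 3], device="cuda")  # GPU-resident
+        #   values = torch.tensor([7, 8, 9], device="cuda")
+        #   torch.repeat_interleave(values, repeats, output_size=6)
+        #   # -> [7, 7, 8, 9, 9, 9]; the output is sized from the host-known
+        #   # 6, so no device sync is needed to compute sum(repeats).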
start_pos_per_token = torch.repeat_interleave( - start_pos_per_seq_gpu, token_num_per_seq_gpu + start_pos_per_seq_gpu, + token_num_per_seq_gpu, + output_size=repeat_output_size, ) attn_metadata.window_topk_batched = _build_window_topk_batched( positions[:total_tokens].to(torch.long), start_pos_per_token, win @@ -1361,7 +1372,9 @@ def _attach_v4_per_fwd_meta( np.asarray(state_slot_mapping_cpu[:scheduled_bs], dtype=np.int32), ) slot_per_token_full = torch.repeat_interleave( - state_slot_mapping_gpu_i32, token_num_per_seq_gpu + state_slot_mapping_gpu_i32, + token_num_per_seq_gpu, + output_size=repeat_output_size, ) attn_metadata.swa_write_indices = write_indices_gpu attn_metadata.swa_positions_filtered = positions[write_indices_gpu].contiguous() From 6a1b7a58bde47c2790f63d7927b9686512d4fc39 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Mon, 4 May 2026 00:23:42 -0700 Subject: [PATCH 4/6] debug: add dsv4 eval component probes --- atom/model_ops/sampler.py | 6 ++ atom/models/deepseek_v4.py | 141 ++++++++++++++++++++++++++++++++++++- 2 files changed, 144 insertions(+), 3 deletions(-) diff --git a/atom/model_ops/sampler.py b/atom/model_ops/sampler.py index 14276c3c1..3d69736fd 100644 --- a/atom/model_ops/sampler.py +++ b/atom/model_ops/sampler.py @@ -75,6 +75,12 @@ def forward( Returns: Sampled token IDs (num_tokens,) """ + # Temperature=0 is a hard greedy request. Handle it before deciding + # whether top-k/top-p filtering is needed; otherwise the no-filter + # path still runs the temperature sampler with an epsilon temperature. + if all_greedy: + return logits.argmax(dim=-1).to(torch.int) + # No Top-K Top-P parameters, perform temperature-based sampling if not self._needs_filtering(top_ks, top_ps): return self._temperature_sample( diff --git a/atom/models/deepseek_v4.py b/atom/models/deepseek_v4.py index 1d064c65f..01e274489 100644 --- a/atom/models/deepseek_v4.py +++ b/atom/models/deepseek_v4.py @@ -104,6 +104,118 @@ # forward burns syscalls (V4-Pro: 64 layers × multiple sites per call). 
_V4_FORCE_UE8M0_QUANT = os.environ.get("V4_FORCE_UE8M0_QUANT", "0") == "1" _V4_USE_REF_QUANT = os.environ.get("V4_USE_REF_QUANT", "0") == "1" +_V4_DIAG_EQUIV = os.environ.get("ATOM_DSV4_DIAG_EQUIV", "0") == "1" +_V4_DIAG_LAYER_SPEC = os.environ.get( + "ATOM_DSV4_DIAG_LAYERS", "0,1,2,3,31,63" +) +_V4_DIAG_TOKEN_LIMIT = int(os.environ.get("ATOM_DSV4_DIAG_TOKEN_LIMIT", "4")) +_V4_DIAG_VERBOSE = os.environ.get("ATOM_DSV4_DIAG_VERBOSE", "0") == "1" +_V4_DIAG_TOL = float(os.environ.get("ATOM_DSV4_DIAG_TOL", "1e-3")) + + +def _v4_diag_layer_enabled(layer_id: int) -> bool: + if not _V4_DIAG_EQUIV: + return False + spec = _V4_DIAG_LAYER_SPEC.strip().lower() + if spec in {"all", "*"}: + return True + try: + return layer_id in {int(x) for x in spec.split(",") if x.strip()} + except ValueError: + return False + + +def _v4_diag_get_equal_batch(input_ids: Optional[torch.Tensor]): + if not _V4_DIAG_EQUIV or input_ids is None: + return None + try: + ctx = get_forward_context() + attn_md = ctx.attn_metadata if ctx is not None else None + cu = getattr(attn_md, "cu_seqlens_q_cpu", None) + if cu is None or attn_md is None or attn_md.block_tables is None: + return None + bs = int(attn_md.block_tables.size(0)) + if bs < 2 or len(cu) < bs + 1: + return None + lens = [int(cu[i + 1] - cu[i]) for i in range(bs)] + if not lens or len(set(lens)) != 1 or lens[0] <= 0: + return None + seqlen = lens[0] + total = bs * seqlen + if input_ids.numel() < total: + return None + ids = input_ids[:total].reshape(bs, seqlen) + if not bool((ids == ids[0:1]).all().item()): + return None + return bs, seqlen, total + except Exception as exc: + print(f"[DSv4 diag] equiv setup skipped: {exc!r}", flush=True) + return None + + +def _v4_diag_selected_tokens(seqlen: int) -> list[int]: + if seqlen <= 0: + return [] + base = [0, seqlen // 2, seqlen - 1] + if _V4_DIAG_TOKEN_LIMIT > 3: + base.extend(range(min(seqlen, _V4_DIAG_TOKEN_LIMIT - 3))) + return sorted(set(i for i in base if 0 <= i < seqlen)) + + +def _v4_diag_check_equiv( + label: str, + tensor: torch.Tensor, + input_ids: Optional[torch.Tensor], +) -> None: + batch = _v4_diag_get_equal_batch(input_ids) + if batch is None: + return + bs, seqlen, total = batch + if tensor.size(0) < total: + return + try: + toks = _v4_diag_selected_tokens(seqlen) + view = tensor[:total].detach().reshape(bs, seqlen, -1).index_select( + 1, torch.tensor(toks, dtype=torch.long, device=tensor.device) + ) + ref = view[0:1].float() + diff = (view.float() - ref).abs() + max_abs = float(diff.max().item()) + mean_abs = float(diff.mean().item()) + bad_rows = int((diff.reshape(bs, -1).amax(dim=1) > _V4_DIAG_TOL).sum().item()) + if _V4_DIAG_VERBOSE or max_abs > _V4_DIAG_TOL: + print( + "[DSv4 diag] " + f"{label}: bs={bs} seqlen={seqlen} toks={toks} " + f"max_abs={max_abs:.6g} mean_abs={mean_abs:.6g} " + f"bad_rows={bad_rows}/{bs}", + flush=True, + ) + except Exception as exc: + print(f"[DSv4 diag] {label}: check failed: {exc!r}", flush=True) + + +def _v4_diag_check_logits(label: str, logits: torch.Tensor) -> None: + if not _V4_DIAG_EQUIV or logits.dim() != 2 or logits.size(0) < 2: + return + try: + ref = logits[0:1].float() + diff = (logits.float() - ref).abs() + max_abs = float(diff.max().item()) + mean_abs = float(diff.mean().item()) + bad_rows = int((diff.amax(dim=1) > _V4_DIAG_TOL).sum().item()) + argmax = logits.argmax(dim=-1).detach().cpu().tolist() + unique_argmax = len(set(int(x) for x in argmax)) + if _V4_DIAG_VERBOSE or max_abs > _V4_DIAG_TOL or unique_argmax > 1: + print( + "[DSv4 diag] " + f"{label}: 
bs={logits.size(0)} max_abs={max_abs:.6g} " + f"mean_abs={mean_abs:.6g} bad_rows={bad_rows}/{logits.size(0)} " + f"unique_argmax={unique_argmax} argmax_head={argmax[:8]}", + flush=True, + ) + except Exception as exc: + print(f"[DSv4 diag] {label}: logits check failed: {exc!r}", flush=True) def _rmsnorm_nw(x: torch.Tensor, eps: float, dim: int) -> torch.Tensor: @@ -1328,7 +1440,6 @@ def __init__(self, layer_id: int, args: DeepseekV4Args, prefix: str = ""): prefix=f"{p}.wqkv_a", ) self.q_norm = RMSNorm(self.q_lora_rank, self.eps) - self.q_norm2 = RMSNorm(self.head_dim, self.eps) self.wq_b = ColumnParallelLinear( self.q_lora_rank, self.n_heads * self.head_dim, @@ -2081,6 +2192,10 @@ def forward( torch.Tensor ], # [num_tokens] int for hash-routed MoE layers ) -> torch.Tensor: # [num_tokens, hc, dim] updated residual stream + diag = _v4_diag_layer_enabled(self.layer_id) + if diag: + _v4_diag_check_equiv(f"L{self.layer_id}.input_mhc", x, input_ids) + # ----- Attention sub-layer with mHC mixing ----- residual = x # [num_tokens, hc, dim] x, post, comb = ( @@ -2088,17 +2203,30 @@ def forward( x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base ) ) + if diag: + _v4_diag_check_equiv(f"L{self.layer_id}.attn_hc_pre", x, input_ids) x = self.attn_norm(x) # [num_tokens, dim] x = self.attn(x, positions) # [num_tokens, dim] + if diag: + _v4_diag_check_equiv(f"L{self.layer_id}.attn_out", x, input_ids) x = self.hc_post(x, residual, post, comb) # [num_tokens, hc, dim] + if diag: + _v4_diag_check_equiv(f"L{self.layer_id}.attn_hc_post", x, input_ids) + # ----- FFN sub-layer with mHC mixing ----- residual = x # [num_tokens, hc, dim] x, post, comb = self.hc_pre( x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base ) + if diag: + _v4_diag_check_equiv(f"L{self.layer_id}.ffn_hc_pre", x, input_ids) x = self.ffn_norm(x) # [num_tokens, dim] x = self.ffn(x, input_ids) # [num_tokens, dim] + if diag: + _v4_diag_check_equiv(f"L{self.layer_id}.ffn_out", x, input_ids) x = self.hc_post(x, residual, post, comb) # [num_tokens, hc, dim] + if diag: + _v4_diag_check_equiv(f"L{self.layer_id}.ffn_hc_post", x, input_ids) return x @@ -2338,8 +2466,10 @@ def forward( """ assert input_ids.dim() == 1, f"input_ids must be 1D, got {input_ids.shape}" h = self.embed(input_ids) # [num_tokens, dim] + _v4_diag_check_equiv("model.embed", h, input_ids) # Expand to hc_mult copies for Hyper-Connections: [num_tokens, hc, dim] h = h.unsqueeze(-2).repeat(1, self.hc_mult, 1) + _v4_diag_check_equiv("model.embed_mhc", h, input_ids) if positions is None: positions = torch.arange( input_ids.numel(), device=input_ids.device, dtype=torch.long @@ -2353,7 +2483,10 @@ def forward( x_hc = self.head.hc_head( # [num_tokens, dim] h, self.hc_head_fn, self.hc_head_scale, self.hc_head_base ) - return self.norm(x_hc) + _v4_diag_check_equiv("model.hc_head", x_hc, input_ids) + out = self.norm(x_hc) + _v4_diag_check_equiv("model.final_norm", out, input_ids) + return out class DeepseekV4ForCausalLM(nn.Module): @@ -2442,7 +2575,9 @@ def compute_logits( # Vocab projection is split off from `model.forward` so the latter # returns hidden_size-shaped tensors — required by ATOM's CUDAGraph # capture contract (outputs buffer is sized to hidden_size, not vocab). - return self.model.head.get_logits(hidden_states) + logits = self.model.head.get_logits(hidden_states) + _v4_diag_check_logits("model.logits", logits) + return logits def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: """Return (param_name, weight_name, expert_id, shard_id) tuples for FusedMoE. 
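The probes added in PATCH 4/6 all reduce to one check: when every sequence in a
batch is fed the same prompt, activations at matching token positions (and the
final logits) should agree across sequences. Below is a standalone sketch of that
core comparison, with hypothetical names and shapes; the in-tree helpers
additionally pull the batch layout from attention metadata and are gated by the
ATOM_DSV4_DIAG_* environment variables:

    import torch

    def max_row_divergence(x: torch.Tensor, tol: float = 1e-3) -> tuple[float, int]:
        # x: [bs, seqlen, dim] hidden states produced for identical prompts.
        # Compare every sequence against sequence 0; report the largest absolute
        # deviation and how many sequences exceed `tol` anywhere.
        ref = x[0:1].float()
        diff = (x.float() - ref).abs()
        max_abs = float(diff.max())
        bad_rows = int((diff.flatten(1).amax(dim=1) > tol).sum())
        return max_abs, bad_rows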
From ac1e4eaf62683a52523acc2235c824d9eeea8a27 Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Mon, 4 May 2026 10:18:21 -0700 Subject: [PATCH 5/6] Cap dummy warmup tokens via env --- atom/model_engine/model_runner.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 4bc474275..1fbac4b86 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -997,6 +997,15 @@ def warmup_model(self): ) dp_size = get_dp_group().world_size warmup_max_tokens = max_num_batched_tokens // dp_size + warmup_cap = int(os.environ.get("ATOM_WARMUP_MAX_NUM_BATCHED_TOKENS", "0")) + if warmup_cap > 0: + capped_warmup_tokens = max(1, warmup_cap // dp_size) + if capped_warmup_tokens < warmup_max_tokens: + logger.info( + f"{self.label}: capping warmup tokens from {warmup_max_tokens} " + f"to {capped_warmup_tokens} via ATOM_WARMUP_MAX_NUM_BATCHED_TOKENS" + ) + warmup_max_tokens = capped_warmup_tokens num_seqs = min(warmup_max_tokens // max_model_len, self.config.max_num_seqs) From deb141f3d9f98595bcea3aa585a37b2788a5c16e Mon Sep 17 00:00:00 2001 From: Oseltamivir Date: Mon, 4 May 2026 11:19:16 -0700 Subject: [PATCH 6/6] Revert "Cap dummy warmup tokens via env" This reverts commit ac1e4eaf62683a52523acc2235c824d9eeea8a27. --- atom/model_engine/model_runner.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/atom/model_engine/model_runner.py b/atom/model_engine/model_runner.py index 1fbac4b86..4bc474275 100644 --- a/atom/model_engine/model_runner.py +++ b/atom/model_engine/model_runner.py @@ -997,15 +997,6 @@ def warmup_model(self): ) dp_size = get_dp_group().world_size warmup_max_tokens = max_num_batched_tokens // dp_size - warmup_cap = int(os.environ.get("ATOM_WARMUP_MAX_NUM_BATCHED_TOKENS", "0")) - if warmup_cap > 0: - capped_warmup_tokens = max(1, warmup_cap // dp_size) - if capped_warmup_tokens < warmup_max_tokens: - logger.info( - f"{self.label}: capping warmup tokens from {warmup_max_tokens} " - f"to {capped_warmup_tokens} via ATOM_WARMUP_MAX_NUM_BATCHED_TOKENS" - ) - warmup_max_tokens = capped_warmup_tokens num_seqs = min(warmup_max_tokens // max_model_len, self.config.max_num_seqs)
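For reference on the indexer change in PATCH 1/6: the AITER scorer consumes a
dense [num_seqs, max_committed, head_dim] KV batch rather than ATOM's paged
layout, so the gathered flat rows are scattered into a zero-padded tensor keyed
by batch id and sequence-local slot, and the zero tail is skipped by the kernel
via kv_lens. A miniature, self-contained illustration with made-up sizes (two
sequences holding 3 and 2 committed entries; the real code derives the index
tensors from the gather metadata):

    import torch

    num_seqs, max_committed, head_dim = 2, 3, 4
    flat = torch.arange(5 * head_dim, dtype=torch.float32).reshape(5, head_dim)
    batch_ids = torch.tensor([0, 0, 0, 1, 1])  # owning sequence of each gathered row
    local_idx = torch.tensor([0, 1, 2, 0, 1])  # slot of that row within its sequence

    kv_batched = torch.zeros((num_seqs, max_committed, head_dim), dtype=flat.dtype)
    kv_batched[batch_ids, local_idx] = flat
    # kv_batched[1, 2] stays all-zero; the scorer ignores it because kv_lens[1] == 2.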