Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 8 additions & 29 deletions atom/model_loader/loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ def have_shared_expert(name):
detect_fused_expert_fn = getattr(model, "detect_fused_expert_format", None)
get_fused_expert_mapping_fn = getattr(model, "get_fused_expert_mapping", None)

with concurrent.futures.ThreadPoolExecutor() as executor:
if True:
executor = None
futures = []
Comment on lines +275 to 276
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F841> reported by reviewdog 🐶
Local variable executor is assigned to but never used

Suggested change
executor = None
futures = []
futures = []

disable_mmap = envs.ATOM_DISABLE_MMAP
Comment on lines 276 to 277
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ [ruff] <F841> reported by reviewdog 🐶
Local variable futures is assigned to but never used

Suggested change
futures = []
disable_mmap = envs.ATOM_DISABLE_MMAP
disable_mmap = envs.ATOM_DISABLE_MMAP

for name, weight_tensor in safetensors_weights_iterator(
Expand Down Expand Up @@ -346,11 +347,7 @@ def have_shared_expert(name):
except AttributeError:
continue
weight_loader = getattr(param, "weight_loader")
futures.append(
executor.submit(
weight_loader, param, weight_tensor, shard_idx
)
)
weight_loader(param, weight_tensor, shard_idx)
loaded_weights_record.add(prefix + param_name)
else:
# Checkpoint has separate weights, load into fused param
Expand All @@ -364,11 +361,7 @@ def have_shared_expert(name):
break
weight_loader = getattr(param, "weight_loader")
# weight_loader(param, weight_tensor, shard_id)
futures.append(
executor.submit(
weight_loader, param, weight_tensor, shard_id
)
)
weight_loader(param, weight_tensor, shard_id)
loaded_weights_record.add(prefix + param_name)
break
else:
Expand Down Expand Up @@ -438,16 +431,7 @@ def have_shared_expert(name):
matched = True
break
weight_loader = getattr(param, "weight_loader")
futures.append(
executor.submit(
weight_loader,
param,
weight_tensor,
name,
shard_id,
expert_id,
)
)
weight_loader(param, weight_tensor, name, shard_id, expert_id)
loaded_weights_record.add(prefix + name)
matched = True
break
Expand All @@ -461,9 +445,7 @@ def have_shared_expert(name):
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
futures.append(
executor.submit(weight_loader, param, weight_tensor)
)
weight_loader(param, weight_tensor)
loaded_weights_record.add(prefix + name)
else:
# Model doesn't have expert mapping, use generic loading
Expand All @@ -474,12 +456,9 @@ def have_shared_expert(name):
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
# weight_loader(param, weight_tensor)
futures.append(executor.submit(weight_loader, param, weight_tensor))
weight_loader(param, weight_tensor)
loaded_weights_record.add(prefix + name)
# Wait for all tasks to complete and raise any exceptions.
for future in concurrent.futures.as_completed(futures):
future.result()
pass # Weight loading done synchronously above

for _, module in model.named_modules():
if hasattr(module, "process_weights_after_loading"):
Expand Down
8 changes: 2 additions & 6 deletions atom/model_ops/activation.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from typing import Optional
from torch import nn
import torch.nn.functional as F
from aiter import silu_and_mul
from atom.config import QuantizationConfig
from atom.quant_spec import LayerQuantConfig
from aiter.jit.utils.torch_guard import torch_compile_guard
Expand Down Expand Up @@ -105,8 +104,5 @@ def forward(
):
return mxfp4_act_mul_quant_fuse(x, shuffle=True)
else:
out = torch.empty(
[*x.shape[:-1], x.shape[-1] // 2], device=x.device, dtype=x.dtype
)
silu_and_mul(out, x)
return out
x, y = x.chunk(2, -1)
return F.silu(x) * y
54 changes: 54 additions & 0 deletions atom/model_ops/attention_mha.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from aiter.ops.triton.fused_kv_cache import fused_qk_rope_reshape_and_cache
from aiter.ops.triton.gluon.pa_decode_gluon import get_recommended_splits
from aiter.ops.triton.unified_attention import unified_attention
from aiter.jit.core import ENABLE_CK
from atom.config import get_current_atom_config
from atom.utils.forward_context import ForwardContext, get_forward_context
from torch import nn
Expand Down Expand Up @@ -500,6 +501,57 @@ def prefill_attention_triton(
attn_metadata = fwd_ctx.attn_metadata
block_tables = attn_metadata.block_tables

if block_tables is None:
# No paged KV cache block_table during prefill — treat each
# token as its own block (block_size=1) and build a fake
# block_table so we can go through unified_attention, which
# supports sinks and sliding window (context_attention_fwd
# does not).
num_seqs = attn_metadata.cu_seqlens_q.shape[0] - 1
seq_lens = attn_metadata.cu_seqlens_q[1:] - attn_metadata.cu_seqlens_q[:-1]
max_seq_len = int(attn_metadata.max_seqlen_q)

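# fake block_table: [num_seqs, max_seq_len], each row is
# [start, start+1, ..., start+max_seq_len-1]. Entries past a sequence's
# own length are NOT zero-padded; they simply keep counting (presumably
# ignored by the kernel via seqused_k — TODO confirm).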
arange = torch.arange(max_seq_len, device=q.device, dtype=torch.int32)
starts = attn_metadata.cu_seqlens_q[:num_seqs].unsqueeze(1)
block_tables = starts + arange.unsqueeze(0)

# Reshape k/v as flash-layout cache with block_size=1:
# [total_tokens, num_kv_heads, head_dim] →
# [total_tokens, 1, num_kv_heads, head_dim]
k_cache_fake = k.unsqueeze(1)
v_cache_fake = v.unsqueeze(1)

o = torch.empty_like(q)
descale_shape = (num_seqs, k.shape[1])
sliding_window = (
(self.sliding_window - 1, 0)
if self.sliding_window is not None
else (-1, -1)
)
unified_attention(
q,
k_cache_fake,
v_cache_fake,
o,
cu_seqlens_q=attn_metadata.cu_seqlens_q,
seqused_k=seq_lens,
max_seqlen_q=max_seq_len,
max_seqlen_k=max_seq_len,
softmax_scale=self.scale,
causal=True,
alibi_slopes=None,
window_size=sliding_window,
block_table=block_tables,
softcap=0,
q_descale=None,
k_descale=self.kv_scale.expand(descale_shape),
v_descale=self.kv_scale.expand(descale_shape),
sinks=self.sinks,
)
return o

o = torch.empty_like(q)
descale_shape = (attn_metadata.cu_seqlens_q.shape[0] - 1, k.shape[1])
sliding_window = (
Expand Down Expand Up @@ -535,6 +587,8 @@ def dispatch_backend(self, fwd_ctx: ForwardContext):
ctx = fwd_ctx.context

if ctx.is_prefill:
if not ENABLE_CK and self.use_triton_attn:
return self.prefill_attention_triton
return self.prefill_attention
else:
if self.use_triton_attn:
Expand Down
91 changes: 51 additions & 40 deletions atom/model_ops/fused_moe_triton.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,18 @@

if has_triton_kernels():
try:
from triton_kernels.matmul_ogs import matmul_ogs
import triton_kernels.swiglu
from triton_kernels.matmul_ogs import (
FnSpecs,
FusedActivation,
PrecisionConfig,
matmul_ogs,
)
from triton_kernels.matmul_ogs_details.opt_flags import (
update_opt_flags_constraints,
reset_opt_flags_constraints,
)
from triton_kernels.routing import routing
from triton_kernels.matmul_ogs import PrecisionConfig
except (AttributeError, ImportError) as e:
logger.error(
"Failed to import Triton kernels. Please make sure your triton "
Expand All @@ -52,9 +61,9 @@ def _swizzle_mxfp4(quant_tensor, scale):
scale_layout_opts: dict[str, Any] = {}
value_layout = StridedLayout
if get_gfx() == "gfx950":
from triton_kernels.tensor_details.layout import GFX950MXScaleLayout
from triton_kernels.tensor_details.layout import CDNA4MXScaleLayout

scale_layout = GFX950MXScaleLayout
scale_layout = CDNA4MXScaleLayout
else:
scale_layout = StridedLayout

Expand Down Expand Up @@ -227,47 +236,49 @@ def triton_kernel_fused_experts(

gammas = routing_data.gate_scal if routing_data else None

# NOTE: We intentionally do NOT use the triton fused SwiGLU activation
# because it expects interleaved [gate0, up0, gate1, up1, ...] layout
# while our w13 weights produce concatenated [gate | up] output.
# It also uses a non-standard formula: s*sigmoid(alpha*s)*(linear+1)
# with alpha=1.702, which differs from the standard SiLU activation
# (x*sigmoid(x)*up) used by most MoE models.
# Instead, we compute the matmul without fused activation and apply
# standard silu(gate) * up manually.
raw_intermediate = torch.empty(
(batch_dim, M * topk, N),
device=hidden_states.device,
dtype=hidden_states.dtype,
act = FusedActivation(
FnSpecs("swiglu", triton_kernels.swiglu.swiglu_fn, ("alpha", "limit")),
(swiglu_alpha, swiglu_limit),
2,
)

matmul_ogs(
hidden_states,
w1,
w1_bias,
routing_data,
gather_indx=gather_indx,
precision_config=w13_precision_config,
gammas=gammas if apply_router_weight_on_input else None,
y=raw_intermediate,
# On CDNA4 (gfx950) with MXFP4 weights, triton_kernels auto-selects
# block_m=256/block_n=512 which exceeds the 160KB shared memory limit.
# Constrain to block_m=128/block_n=256 to stay within hardware limits.
needs_block_cap = (
get_gfx() == "gfx950"
and w13_precision_config is not None
and w13_precision_config.weight_scale is not None
)
if needs_block_cap:
update_opt_flags_constraints({"block_m": 128, "block_n": 256})

# Standard SiLU/SwiGLU activation: silu(gate) * up
raw_2d = raw_intermediate.view(M * topk, N)
gate = raw_2d[:, :half_N]
up = raw_2d[:, half_N:]
intermediate_cache[0] = torch.nn.functional.silu(gate) * up
try:
matmul_ogs(
hidden_states,
w1,
w1_bias,
routing_data,
gather_indx=gather_indx,
precision_config=w13_precision_config,
gammas=gammas if apply_router_weight_on_input else None,
fused_activation=act,
y=intermediate_cache,
)

matmul_ogs(
intermediate_cache.view(M * topk, half_N),
w2,
w2_bias,
routing_data,
scatter_indx=scatter_indx,
precision_config=w2_precision_config,
gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor,
)
matmul_ogs(
intermediate_cache.view(M * topk, half_N),
w2,
w2_bias,
routing_data,
scatter_indx=scatter_indx,
precision_config=w2_precision_config,
gammas=None if apply_router_weight_on_input else gammas,
y=output_tensor,
)
finally:
if needs_block_cap:
reset_opt_flags_constraints()

output_tensor = output_tensor.view(M, K)
return output_tensor
6 changes: 4 additions & 2 deletions atom/model_ops/moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from aiter import ActivationType, QuantType, dtypes, get_hip_quant
from aiter.dist.parallel_state import get_dp_group, get_tp_group
from aiter.fused_moe import fused_moe
from aiter.jit.core import ENABLE_CK
from aiter.jit.utils.chip_info import get_gfx
from aiter.jit.utils.torch_guard import torch_compile_guard
from aiter.ops.shuffle import shuffle_scale_a16w4, shuffle_weight_a16w4
Expand Down Expand Up @@ -641,8 +642,9 @@ def __init__(self, quant_config: LayerQuantConfig, moe: FusedMoEConfig):
or self.quant_type == QuantType.per_1x32
)
gfx = get_gfx()
self.use_triton = gfx.startswith("gfx94") or (
gfx.startswith("gfx95") and envs.ATOM_USE_TRITON_GEMM
# Route MXFP4 MoE to Triton on gfx94/gfx12, or on gfx95 when ATOM_USE_TRITON_GEMM is set or CK is unavailable.
self.use_triton = gfx.startswith("gfx94") or gfx.startswith("gfx12") or (
gfx.startswith("gfx95") and (envs.ATOM_USE_TRITON_GEMM or not ENABLE_CK)
)
if self.use_triton:
from atom.model_ops.utils import has_triton_kernels
Expand Down
4 changes: 2 additions & 2 deletions atom/model_ops/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from functools import lru_cache

import torch
from aiter import mixed_sample_outer_exponential
from aiter.ops.triton.sample import mixed_sample_outer_exponential_triton
from aiter.ops.triton.softmax import softmax
from aiter.ops.triton.topk import topk
from torch import nn
Expand Down Expand Up @@ -99,7 +99,7 @@ def _temperature_sample(
exponential = get_per_token_exponential(vocab_size, logits.device).expand(
num_tokens, vocab_size
)
mixed_sample_outer_exponential(
mixed_sample_outer_exponential_triton(
sampled_tokens, logits, exponential, temperatures, eps=self.eps
)
return sampled_tokens
Expand Down
Loading
Loading