Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 11 additions & 6 deletions atom/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,16 @@ def _matches_exclude(
return True
return False

def apply_exclude_name_mapping(self, mapping: dict[str, str]):
    """Rewrite each entry of ``self.exclude_layers`` through *mapping*.

    Every ``(old, new)`` pair in *mapping* is applied as a plain substring
    replacement to every exclusion pattern, in the mapping's insertion
    order. Duplicate results are dropped while keeping first-seen order.
    No-op when either *mapping* or the exclusion list is empty.
    """
    if not (mapping and self.exclude_layers):
        return

    # Apply every substitution pair to a single pattern, in order.
    def _rewrite(pattern: str) -> str:
        for src, dst in mapping.items():
            pattern = pattern.replace(src, dst)
        return pattern

    remapped = (_rewrite(p) for p in self.exclude_layers)
    # dict.fromkeys de-duplicates while preserving insertion order.
    self.exclude_layers = list(dict.fromkeys(remapped))

def remap_layer_name(
self,
hf_config: PretrainedConfig,
Expand Down Expand Up @@ -467,12 +477,7 @@ def _remap_layer_name(name: str) -> list[str]:
# Models that have a mismatch between their HF quant config names and ATOM
# module paths declare `quant_exclude_name_mapping` as a class attribute.
if quant_exclude_name_mapping:
new_excludes = []
for name in self.exclude_layers:
for old, new in quant_exclude_name_mapping.items():
name = name.replace(old, new)
new_excludes.append(name)
self.exclude_layers = list(dict.fromkeys(new_excludes))
self.apply_exclude_name_mapping(quant_exclude_name_mapping)


_CONFIG_REGISTRY: dict[str, str] = {
Expand Down
4 changes: 4 additions & 0 deletions atom/model_ops/attention_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
from atom.model_ops.utils import get_and_maybe_dequant_weights
from atom.plugin import is_plugin_mode
from atom.plugin.attention_mla import MLAAttentionImplDecoratorForPluginMode
from atom.plugin.attention_mla_sparse import (
MLASparseAttentionImplDecoratorForPluginMode,
)
from atom.utils import envs
from atom.utils.decorators import mark_trace
from atom.utils.forward_context import (
Expand Down Expand Up @@ -106,6 +109,7 @@ def dynamic_per_batched_tensor_quant(
return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


@MLASparseAttentionImplDecoratorForPluginMode
@MLAAttentionImplDecoratorForPluginMode
class MLAAttention(nn.Module):
Comment thread
wuhuikx marked this conversation as resolved.
def __init__(
Expand Down
5 changes: 1 addition & 4 deletions atom/model_ops/paged_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,6 @@ def __init__(
extra_impl_args["mla_modules"] = mla_modules

if use_mla:
assert (
mla_modules.indexer is None
), "MLAAttention is not supported for sparse mode"
self.num_heads = num_heads
self.v_head_dim = mla_modules.v_head_dim
self.qk_head_dim = mla_modules.qk_head_dim
Expand All @@ -114,7 +111,7 @@ def __init__(
quant_config=quant_config,
prefix=f"{prefix}.attn",
kv_b_proj=mla_modules.kv_b_proj,
use_sparse=False,
use_sparse=mla_modules.indexer is not None,
indexer=mla_modules.indexer,
**extra_impl_args,
)
Expand Down
10 changes: 9 additions & 1 deletion atom/models/deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@
from atom.utils.custom_register import direct_register_custom_op
from atom.utils.decorators import mark_trace, support_torch_compile
from atom.utils.forward_context import get_forward_context
from atom.plugin.attention_mla_sparse import (
IndexerDecoratorForPluginMode,
DeepseekV32IndexerCacheDecoratorForPluginMode,
)
from torch import nn
from transformers import PretrainedConfig

Expand Down Expand Up @@ -968,6 +972,7 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
return 0.1 * mscale * math.log(scale) + 1.0


@DeepseekV32IndexerCacheDecoratorForPluginMode
class DeepseekV32IndexerCache(nn.Module):

def __init__(
Expand Down Expand Up @@ -1144,6 +1149,7 @@ def sparse_attn_indexer_fake(
)


@IndexerDecoratorForPluginMode
class Indexer(nn.Module):

def __init__(
Expand Down Expand Up @@ -1207,6 +1213,8 @@ def __init__(
self.max_total_seq_len = atom_config.max_num_seqs * self.max_model_len
# register_metadata_builder("indexer_attn_metadata", self.k_cache.get_attn_backend().get_builder_cls())

self.sparse_attn_indexer_impl = torch.ops.aiter.sparse_attn_indexer

def forward(
self,
hidden_states: torch.Tensor,
Expand Down Expand Up @@ -1242,7 +1250,7 @@ def forward(
)
weights = weights.squeeze(-1)

return torch.ops.aiter.sparse_attn_indexer(
return self.sparse_attn_indexer_impl(
hidden_states,
self.k_cache.prefix,
self.k_cache.kv_cache[0],
Expand Down
Loading
Loading