本项目实现了 Mamba 模型,作为一种前沿的高效序列建模方法,旨在突破 Transformer 在长序列处理中面临的二次方计算复杂度瓶颈。Mamba 的核心在于其选择性状态空间模型(Selective State Space Model, SSM),它通过输入依赖的动态权重更新来有效捕捉长距离依赖,同时保持线性计算复杂度。
Mamba 模型的实现主要参考了官方和社区的简化版本,并针对 PyTorch 环境进行了适配和优化。模型的定义主要位于 models/efficient/mamba.py 文件中,并依赖于 models/efficient/pscan.py 中实现的并行扫描操作。
MambaModel 类是整个 Mamba 架构的入口点,它负责整合嵌入层、多个 Mamba 层以及最终的分类/回归头。
import math
// ... existing code ...
class MambaModel(nn.Module):
def __init__(
self,
config: MambaConfig,
num_labels: int = 2,
task_type: str = "classification"
):
super().__init__()
self.config = config
self.task_type = task_type
self.num_labels = num_labels
self.class_weights = None
# 词嵌入层
self.embeddings = nn.Embedding(config.vocab_size, config.d_model)
# 添加token类型嵌入层
self.token_type_embeddings = nn.Embedding(2, config.d_model)
self.embedding_norm = nn.LayerNorm(config.d_model, eps=config.rms_norm_eps)
self.embedding_dropout = nn.Dropout(config.dropout_prob)
# Mamba层
self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])
# 改进的分类/回归头
if task_type == "regression":
self.classifier = nn.Sequential(
nn.Linear(config.d_model, config.d_model // 2),
nn.ReLU(),
nn.Dropout(config.dropout_prob),
nn.Linear(config.d_model // 2, 1),
nn.Sigmoid() # 将输出映射到[0,1]范围
)
else:
// ... existing code ...
# 初始化权重
self.apply(self._init_weights)
def _init_weights(self, module):
// ... existing code ...
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None
) -> Dict[str, torch.Tensor]:
// ... existing code ...
# 生成词嵌入
hidden_states = self.embeddings(input_ids)
# 添加token类型嵌入
if token_type_ids is not None:
token_type_embeddings = self.token_type_embeddings(token_type_ids)
hidden_states = hidden_states + token_type_embeddings
# 应用层归一化和dropout
hidden_states = self.embedding_norm(hidden_states)
hidden_states = self.embedding_dropout(hidden_states)
# 应用填充掩码
if attention_mask is not None:
hidden_states = hidden_states * attention_mask.unsqueeze(-1)
# Mamba层
for layer in self.layers:
hidden_states = layer(hidden_states)
# 改进的池化策略:使用平均池化
if attention_mask is not None:
// ... existing code ...
else:
pooled_output = hidden_states.mean(1)
// ... existing code ...
return outputs- 嵌入层 (L24-L27):
self.embeddings: 负责将输入的input_ids(token 整数 ID)转换为连续的词向量。self.token_type_embeddings: 类似于 BERT,用于区分输入序列的段落(如句子 A/B),为每个 token 分配类型嵌入。self.embedding_norm: 在嵌入后应用 RMS Normalization (L26-L26) 进行归一化,有助于稳定训练。self.embedding_dropout: 引入 Dropout (L27-L27) 进行正则化。
- Mamba 层 (
self.layers, L30-L30): 模型的核心,由多个ResidualBlock堆叠而成,每个ResidualBlock包含一个MambaBlock。 - 分类器/回归头 (L33-L49): 根据任务类型(分类或回归)配置不同的输出层。
- 分类任务使用一个包含 ReLU 激活和 Dropout 的多层感知机,最终输出
num_labels个 logits。 - 回归任务也使用多层感知机,但最终通过 Sigmoid 激活函数将输出映射到
[0,1]范围,然后可进一步缩放到特定回归目标范围(如[0,5])。
- 分类任务使用一个包含 ReLU 激活和 Dropout 的多层感知机,最终输出
- 权重初始化 (
_init_weights, L52-L61): 采用 Kaiming Normal 初始化线性层,正态分布初始化嵌入层,以及 RMSNorm 权重初始化为 1、偏置为 0 的策略。 - 前向传播 (
forward, L63-L121):- 输入
input_ids经过词嵌入和 token 类型嵌入,并进行归一化和 Dropout。 - 填充掩码 (L81-L82): 模型的填充掩码通过将
hidden_states乘以attention_mask.unsqueeze(-1)来实现,这使得填充位置的特征向量变为零,从而在后续计算中被忽略。 hidden_states顺序通过所有ResidualBlock层。- 池化策略 (L88-L94): Mamba 模型通常使用最后一个 token 的表示或平均池化来获取序列的整体表示。在此实现中,使用平均池化,并通过
attention_mask处理填充,确保只有非填充 token 参与平均。 - 池化后的输出传递给
classifier生成logits。 - 对于回归任务,
logits会进行额外的缩放。分类任务在推理时会计算probs。 - 损失计算(如果提供了
labels)也与基线模型一致,支持类别加权交叉熵或均方误差。
- 输入
ResidualBlock 是 Mamba 模型的中间层,它结合了一个 MambaBlock、一个归一化层和残差连接,遵循了深度学习中常用的残差网络(ResNet)结构,有助于训练更深的模型。
class ResidualBlock(nn.Module):
def __init__(self, config: MambaConfig):
super().__init__()
self.mixer = MambaBlock(config)
self.norm = RMSNorm(config.d_model, config.rms_norm_eps, config.mup)
def forward(self, x):
# x : (B, L, D)
# output : (B, L, D)
output = self.mixer(self.norm(x)) + x
return outputself.mixer(L137-L137): 一个MambaBlock实例,是处理序列的核心组件。self.norm(L138-L138): 一个RMSNorm实例,在将输入传递给MambaBlock之前进行归一化。这是一种替代 Layer Normalization 的归一化方法。- 前向传播 (
forward, L141-L144): 实现残差连接。输入x首先经过norm归一化,然后送入mixer(即MambaBlock)。mixer的输出与原始输入x相加,形成残差连接,有助于梯度的顺畅传播。
MambaBlock 是 Mamba 架构最核心的部分,实现了选择性状态空间模型。它将输入 x 映射到两个分支,一个分支进行短卷积和选择性扫描,另一个分支作为门控机制。
class MambaBlock(nn.Module):
def __init__(self, config: MambaConfig):
super().__init__()
self.config = config
# projects block input from D to 2*ED (two branches)
self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)
self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner,
kernel_size=config.d_conv, bias=config.conv_bias,
groups=config.d_inner,
padding=config.d_conv - 1)
# projects x to input-dependent delta, B, C
self.x_proj = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)
# projects delta from dt_rank to d_inner
self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)
# dt initialization
// ... existing code ...
# S4D real initialization
A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
self.A_log = nn.Parameter(torch.log(A)) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ?
self.A_log._no_weight_decay = True
self.D = nn.Parameter(torch.ones(config.d_inner))
self.D._no_weight_decay = True
# projects block output from ED back to D
self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)
if self.config.inner_layernorms:
self.dt_layernorm = RMSNorm(self.config.dt_rank, config.rms_norm_eps, config.mup)
self.B_layernorm = RMSNorm(self.config.d_state, config.rms_norm_eps, config.mup)
self.C_layernorm = RMSNorm(self.config.d_state, config.rms_norm_eps, config.mup)
else:
self.dt_layernorm = None;
self.B_layernorm = None;
self.C_layernorm = None;
if self.config.use_cuda:
try:
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
self.selective_scan_cuda = selective_scan_fn
except ImportError:
print("Failed to import mamba_ssm. Falling back to mamba.py.")
self.config.use_cuda = False;
def _apply_layernorms(self, dt, B, C):
// ... existing code ...
def forward(self, x):
# x : (B, L, D)
# y : (B, L, D)
_, L, _ = x.shape
xz = self.in_proj(x) # (B, L, 2*ED)
x, z = xz.chunk(2, dim=-1) # (B, L, ED), (B, L, ED)
# x branch
x = x.transpose(1, 2) # (B, ED, L)
x = self.conv1d(x)[:, :, :L] # depthwise convolution over time, with a short filter
x = x.transpose(1, 2) # (B, L, ED)
x = F.silu(x);
y = self.ssm(x, z);
if self.config.use_cuda:
output = self.out_proj(y) # (B, L, D)
return output # the rest of the operations are done in the ssm function (fused with the CUDA pscan)
# z branch
z = F.silu(z);
output = y * z;
output = self.out_proj(output) # (B, L, D)
return output
def ssm(self, x, z):
# x : (B, L, ED)
# y : (B, L, ED)
A = -torch.exp(self.A_log.float()) # (ED, N)
D = self.D.float();
deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N)
delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
delta, B, C = self._apply_layernorms(delta, B, C);
delta = self.dt_proj.weight @ delta.transpose(1, 2) # (ED, dt_rank) @ (B, L, dt_rank) -> (B, ED, L)
// ... existing code ...
return y- 输入投影 (
in_proj, L160-L160): 将输入x从d_model维度投影到2 * d_inner维度,并分割成两个分支x和z。x分支用于 SSM,z分支用于门控。 - 1D 因果卷积 (
conv1d, L162-L166): 对x分支应用一个短的 1D 深度可分离卷积(groups=config.d_inner),这是一种轻量级的局部特征提取方式。padding设置为d_conv-1以确保输出长度与输入相同,并实现因果性(不查看未来信息)。 x_proj(L169-L169): 将卷积后的x投影到dt_rank + 2 * d_state维度,用于生成delta、B和C参数。dt_proj(L172-L172): 将delta从dt_rank维度投影到d_inner维度。- 参数初始化 (L175-L208):
dt_proj的偏置 (L183-L188) 初始化为inv_dt,这是一个经过特殊计算的值,旨在控制delta的范围,使其在 softplus 激活后能够更好地工作。A_log(L191-L193) 和D(L195-L196) 是 SSM 的固定参数。A_log存储了对角线状态矩阵 A 的对数,以便在训练中保持 A 的负值,确保稳定性。
- 输出投影 (
out_proj, L200-L200): 将 SSM 的输出从d_inner维度投影回d_model维度。 - 内部层归一化 (
inner_layernorms, L203-L208, L214-L219): 可选地在delta、B和C上应用 RMS Normalization,进一步稳定这些关键参数的范围。 - CUDA 加速 (
use_cuda, L210-L213, L247-L255): 如果环境支持且配置启用,Mamba 可以调用mamba_ssm库中优化的 CUDA 核函数selective_scan_fn来加速选择性扫描操作,实现与官方实现相似的高性能。 - 前向传播 (
forward, L221-L268):- 将输入
x分割为x和z。 x经过 1D 卷积,然后通过F.silu激活。- 核心的 选择性扫描 (
self.ssm) 操作在x和z上执行。 - 如果使用 CUDA,
ssm的输出直接通过out_proj。 - 如果未使用 CUDA,
z分支也经过F.silu激活,然后与ssm的输出y相乘(门控机制),最后通过out_proj。
- 将输入
- 选择性扫描 (
ssm, L270-L308): 这是 Mamba 的核心,它根据输入动态计算状态更新参数。A参数通过self.A_log计算得到 (L275-L275)。deltaBC(L279-L279) 由x_proj生成,并被分割为delta、B和C。delta、B和C经过可选的内部层归一化 (L281-L281)。delta经过dt_proj线性变换 (L282-L282) 和 softplus 激活 (L295-L296)。- 根据
use_cuda配置,调用 CUDA 优化的selective_scan_cuda或 PyTorch 实现的selective_scan(并行) /selective_scan_seq(顺序) 版本。
选择性扫描是 SSM 的核心,它将线性的状态空间模型(State Space Model)转换为输入依赖的动态系统,从而能够捕捉长距离依赖。
4.1. pscan (Parallel Scan)
pscan 函数(定义在 models/efficient/pscan.py 中)实现了 Blelloch 版本的并行扫描操作。它是一个 torch.autograd.Function,意味着它提供了自定义的前向传播和反向传播逻辑,以确保计算的高效性和梯度正确性。
def npo2(len):
// ... existing code ...
def pad_npo2(X):
// ... existing code ...
class PScan(torch.autograd.Function):
@staticmethod
def pscan(A, X):
// ... existing code ...
@staticmethod
def pscan_rev(A, X):
// ... existing code ...
@staticmethod
def forward(ctx, A_in, X_in):
// ... existing code ...
PScan.pscan(A, X) // Core parallel scan operation
// ... existing code ...
@staticmethod
def backward(ctx, grad_output_in):
// ... existing code ...
PScan.pscan_rev(A, grad_output) // Reverse parallel scan for gradients
// ... existing code ...- 目的:
pscan的目标是高效地计算一系列递归关系H[t] = A[t] * H[t-1] + X[t](H[0] = 0)。传统的顺序计算需要O(L)步,而并行扫描可以在O(log L)步内完成,显著加速长序列处理。 - 实现:
PScan.pscan方法通过"向上扫描"(up sweep)和"向下扫描"(down sweep)两个阶段实现并行化。它通过重新排列和合并A和X矩阵来并行计算中间结果,最终得到所有时间步的H值。 - 反向传播:
PScan.backward方法利用pscan_rev(L125-L125) 实现反向并行扫描,以高效地计算A和X的梯度。
4.2. selective_scan 和 selective_scan_seq (在 mamba.py 中)
MambaBlock 中的 ssm 方法会根据配置选择调用 selective_scan (并行版本,基于 pscan) 或 selective_scan_seq (顺序版本,用于对比或调试)。
def selective_scan(self, x, delta, A, B, C, D):
// ... existing code ...
deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
hs = pscan(deltaA, BX) // 调用pScan进行并行扫描
y = (hs @ C.unsqueeze(-1)).squeeze(3) // (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
y = y + D * x
return y
def selective_scan_seq(self, x, delta, A, B, C, D):
// ... existing code ...
deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)
BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
hs = []
for t in range(0, L): // 顺序循环
h = deltaA[:, t] * h + BX[:, t]
hs.append(h)
hs = torch.stack(hs, dim=1) # (B, L, ED, N)
y = (hs @ C.unsqueeze(-1)).squeeze(3) // (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)
y = y + D * x
return y- 输入相关参数:
x是原始输入,delta、B、C是由输入依赖动态生成的。A和D是固定的模型参数。 deltaA和deltaB(L325-L326, L347-L348): 这些是核心的选择性机制。delta参数控制了A和B矩阵的更新,使得 SSM 能够根据输入内容动态地"选择"记住或忘记信息。BX(L328-L328, L350-L350):deltaB与输入x相乘,表示输入如何被添加到状态中。- 状态更新 (
hs = pscan(deltaA, BX)/ 循环 L355-L357):selective_scan通过调用pscan(L330-L330) 来并行计算所有时间步的隐状态h。selective_scan_seq则通过一个for循环 (L355-L357) 逐时间步地更新隐状态h(h = deltaA[:, t] * h + BX[:, t]),这模拟了 RNN 的行为,但计算效率较低。
- 输出 (
y, L332-L335, L360-L363): 最终的输出y是隐状态hs与C矩阵的乘积,加上残差连接D * x。这表示当前状态如何被投影到输出空间。
RMSNorm 是一种替代传统的 Layer Normalization 的归一化方法,它只对输入的均方根进行归一化,而不需要减去均值。
class RMSNorm(nn.Module):
def __init__(self, d_model: int, eps: float = 1e-5, use_mup: bool = False):
super().__init__()
self.use_mup = use_mup
self.eps = eps
// ... existing code ...
def forward(self, x):
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
if not self.use_mup:
return output * self.weight
else:
return output- 原理 (L492-L492):
output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)。它计算输入x沿着最后一个维度的平方均值,然后取平方根的倒数,再乘以输入x。这有效地将输入的 L2 范数归一化到 1。 self.weight(L485-L485, L494-L495): 可选的仿射变换(即学习到的增益参数),通常在 RMSNorm 中使用,以恢复模型的表达能力。
Mamba 模型的一大优势是其在自回归推理时具有常数时间复杂度。通过在每个 MambaBlock 中维护一个固定大小的缓存,模型可以高效地生成序列。
def step(self, x, caches):
// ... existing code ...
for i, layer in enumerate(self.layers):
x, caches[i] = layer.step(x, caches[i])
return x, caches
class ResidualBlock(nn.Module):
// ... existing code ...
def step(self, x, cache):
// ... existing code ...
output, cache = self.mixer.step(self.norm(x), cache)
output = output + x
return output, cache
class MambaBlock(nn.Module):
// ... existing code ...
def step(self, x, cache):
// ... existing code ...
h, inputs = cache
xz = self.in_proj(x) # (B, 2*ED)
x, z = xz.chunk(2, dim=1) # (B, ED), (B, ED)
# x branch
x_cache = x.unsqueeze(2)
x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[:, :, self.config.d_conv-1] # (B, ED)
x = F.silu(x);
y, h = self.ssm_step(x, h);
// ... existing code ...
# prepare cache for next call
inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2) # (B, ED, d_conv-1)
cache = (h, inputs)
return output, cache
def ssm_step(self, x, h):
// ... existing code ...
h = deltaA * h + BX # (B, ED, N)
y = (h @ C.unsqueeze(-1)).squeeze(2);
y = y + D * x;
return y, h- 缓存结构: 每个
MambaBlock在推理时需要维护两个关键的缓存信息:- 隐状态
h: 对应于 SSM 的循环状态,形状为(B, ED, N)。 - 最近的卷积输入
inputs: 对应于 1D 卷积的d_conv-1个历史输入,形状为(B, ED, d_conv-1)。由于d_conv通常很小(如 4),这个缓存的大小是固定的,不会随序列长度增长。
- 隐状态
MambaBlock.step()(L411-L450): 在自回归生成时,此方法接收单个时间步的输入x和上一时间步的缓存。它执行单步卷积和 SSM 计算,并返回当前时间步的输出和更新后的缓存,供下一时间步使用。inputs缓存通过移除最旧的输入并添加当前输入来更新。ssm_step()(L452-L476): 这是MambaBlock内部用于单步推理的 SSM 计算,它接收当前时间步的x和上一时间步的h,并更新h。
通过上述详细的模块设计和实现,Mamba 模型能够高效地处理长序列,并在保持线性复杂度的同时,有效地捕捉上下文信息,为超越 Transformer 的序列建模提供了新的方向。