Skip to content

Latest commit

 

History

History
442 lines (334 loc) · 21.6 KB

File metadata and controls

442 lines (334 loc) · 21.6 KB

Mamba 模型架构设计与实现

本项目实现了 Mamba 模型,作为一种前沿的高效序列建模方法,旨在突破 Transformer 在长序列处理中面临的二次方计算复杂度瓶颈。Mamba 的核心在于其选择性状态空间模型(Selective State Space Model, SSM),它通过输入依赖的动态权重更新来有效捕捉长距离依赖,同时保持线性计算复杂度。

Mamba 模型的实现主要参考了官方和社区的简化版本,并针对 PyTorch 环境进行了适配和优化。模型的定义主要位于 models/efficient/mamba.py 文件中,并依赖于 models/efficient/pscan.py 中实现的并行扫描操作。

1. Mamba 模型 (MambaModel) 顶层结构

MambaModel 类是整个 Mamba 架构的入口点,它负责整合嵌入层、多个 Mamba 层以及最终的分类/回归头。

import math
// ... existing code ...

class MambaModel(nn.Module):
    def __init__(
        self,
        config: MambaConfig,
        num_labels: int = 2,
        task_type: str = "classification"
    ):
        super().__init__()
        self.config = config
        self.task_type = task_type
        self.num_labels = num_labels
        self.class_weights = None

        # 词嵌入层
        self.embeddings = nn.Embedding(config.vocab_size, config.d_model)
        # 添加token类型嵌入层
        self.token_type_embeddings = nn.Embedding(2, config.d_model)
        self.embedding_norm = nn.LayerNorm(config.d_model, eps=config.rms_norm_eps)
        self.embedding_dropout = nn.Dropout(config.dropout_prob)

        # Mamba层
        self.layers = nn.ModuleList([ResidualBlock(config) for _ in range(config.n_layers)])
        
        # 改进的分类/回归头
        if task_type == "regression":
            self.classifier = nn.Sequential(
                nn.Linear(config.d_model, config.d_model // 2),
                nn.ReLU(),
                nn.Dropout(config.dropout_prob),
                nn.Linear(config.d_model // 2, 1),
                nn.Sigmoid()  # 将输出映射到[0,1]范围
            )
        else:
            // ... existing code ...
        
        # 初始化权重
        self.apply(self._init_weights)

    def _init_weights(self, module):
        // ... existing code ...

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None
    ) -> Dict[str, torch.Tensor]:
        // ... existing code ...
        
        # 生成词嵌入
        hidden_states = self.embeddings(input_ids)
        
        # 添加token类型嵌入
        if token_type_ids is not None:
            token_type_embeddings = self.token_type_embeddings(token_type_ids)
            hidden_states = hidden_states + token_type_embeddings
        
        # 应用层归一化和dropout
        hidden_states = self.embedding_norm(hidden_states)
        hidden_states = self.embedding_dropout(hidden_states)
        
        # 应用填充掩码
        if attention_mask is not None:
            hidden_states = hidden_states * attention_mask.unsqueeze(-1)
        
        # Mamba层
        for layer in self.layers:
            hidden_states = layer(hidden_states)
        
        # 改进的池化策略:使用平均池化
        if attention_mask is not None:
            // ... existing code ...
        else:
            pooled_output = hidden_states.mean(1)
        
        // ... existing code ...
        
        return outputs
  • 嵌入层 (L24-L27):
    • self.embeddings: 负责将输入的 input_ids(token 整数 ID)转换为连续的词向量。
    • self.token_type_embeddings: 类似于 BERT,用于区分输入序列的段落(如句子 A/B),为每个 token 分配类型嵌入。
    • self.embedding_norm: 在嵌入后应用 RMS Normalization (L26-L26) 进行归一化,有助于稳定训练。
    • self.embedding_dropout: 引入 Dropout (L27-L27) 进行正则化。
  • Mamba 层 (self.layers, L30-L30): 模型的核心,由多个 ResidualBlock 堆叠而成,每个 ResidualBlock 包含一个 MambaBlock
  • 分类器/回归头 (L33-L49): 根据任务类型(分类或回归)配置不同的输出层。
    • 分类任务使用一个包含 ReLU 激活和 Dropout 的多层感知机,最终输出 num_labels 个 logits。
    • 回归任务也使用多层感知机,但最终通过 Sigmoid 激活函数将输出映射到 [0,1] 范围,然后可进一步缩放到特定回归目标范围(如 [0,5])。
  • 权重初始化 (_init_weights, L52-L61): 采用 Kaiming Normal 初始化线性层,正态分布初始化嵌入层,以及 RMSNorm 权重初始化为 1、偏置为 0 的策略。
  • 前向传播 (forward, L63-L121):
    • 输入 input_ids 经过词嵌入和 token 类型嵌入,并进行归一化和 Dropout。
    • 填充掩码 (L81-L82): 模型的填充掩码通过将 hidden_states 乘以 attention_mask.unsqueeze(-1) 来实现,这使得填充位置的特征向量变为零,从而在后续计算中被忽略。
    • hidden_states 顺序通过所有 ResidualBlock 层。
    • 池化策略 (L88-L94): Mamba 模型通常使用最后一个 token 的表示或平均池化来获取序列的整体表示。在此实现中,使用平均池化,并通过 attention_mask 处理填充,确保只有非填充 token 参与平均。
    • 池化后的输出传递给 classifier 生成 logits
    • 对于回归任务,logits 会进行额外的缩放。分类任务在推理时会计算 probs
    • 损失计算(如果提供了 labels)也与基线模型一致,支持类别加权交叉熵或均方误差。

2. 残差块 (ResidualBlock)

ResidualBlock 是 Mamba 模型的中间层,它结合了一个 MambaBlock、一个归一化层和残差连接,遵循了深度学习中常用的残差网络(ResNet)结构,有助于训练更深的模型。

class ResidualBlock(nn.Module):
    def __init__(self, config: MambaConfig):
        super().__init__()

        self.mixer = MambaBlock(config)
        self.norm = RMSNorm(config.d_model, config.rms_norm_eps, config.mup)

    def forward(self, x):
        # x : (B, L, D)
        # output : (B, L, D)
        output = self.mixer(self.norm(x)) + x
        return output
  • self.mixer (L137-L137): 一个 MambaBlock 实例,是处理序列的核心组件。
  • self.norm (L138-L138): 一个 RMSNorm 实例,在将输入传递给 MambaBlock 之前进行归一化。这是一种替代 Layer Normalization 的归一化方法。
  • 前向传播 (forward, L141-L144): 实现残差连接。输入 x 首先经过 norm 归一化,然后送入 mixer (即 MambaBlock)。mixer 的输出与原始输入 x 相加,形成残差连接,有助于梯度的顺畅传播。

3. Mamba 核心块 (MambaBlock)

MambaBlock 是 Mamba 架构最核心的部分,实现了选择性状态空间模型。它将输入 x 映射到两个分支,一个分支进行短卷积和选择性扫描,另一个分支作为门控机制。

class MambaBlock(nn.Module):
    def __init__(self, config: MambaConfig):
        super().__init__()

        self.config = config

        # projects block input from D to 2*ED (two branches)
        self.in_proj = nn.Linear(config.d_model, 2 * config.d_inner, bias=config.bias)

        self.conv1d = nn.Conv1d(in_channels=config.d_inner, out_channels=config.d_inner, 
                              kernel_size=config.d_conv, bias=config.conv_bias, 
                              groups=config.d_inner,
                              padding=config.d_conv - 1)
        
        # projects x to input-dependent delta, B, C
        self.x_proj = nn.Linear(config.d_inner, config.dt_rank + 2 * config.d_state, bias=False)

        # projects delta from dt_rank to d_inner
        self.dt_proj = nn.Linear(config.dt_rank, config.d_inner, bias=True)

        # dt initialization
        // ... existing code ...

        # S4D real initialization
        A = torch.arange(1, config.d_state + 1, dtype=torch.float32).repeat(config.d_inner, 1)
        self.A_log = nn.Parameter(torch.log(A)) # why store A in log ? to keep A < 0 (cf -torch.exp(...)) ? for gradient stability ?
        self.A_log._no_weight_decay = True

        self.D = nn.Parameter(torch.ones(config.d_inner))
        self.D._no_weight_decay = True

        # projects block output from ED back to D
        self.out_proj = nn.Linear(config.d_inner, config.d_model, bias=config.bias)

        if self.config.inner_layernorms:
            self.dt_layernorm = RMSNorm(self.config.dt_rank, config.rms_norm_eps, config.mup)
            self.B_layernorm = RMSNorm(self.config.d_state, config.rms_norm_eps, config.mup)
            self.C_layernorm = RMSNorm(self.config.d_state, config.rms_norm_eps, config.mup)
        else:
            self.dt_layernorm = None;
            self.B_layernorm = None;
            self.C_layernorm = None;

        if self.config.use_cuda:
            try:
                from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
                self.selective_scan_cuda = selective_scan_fn
            except ImportError:
                print("Failed to import mamba_ssm. Falling back to mamba.py.")
                self.config.use_cuda = False;

    def _apply_layernorms(self, dt, B, C):
        // ... existing code ...

    def forward(self, x):
        # x : (B, L, D)
        # y : (B, L, D)

        _, L, _ = x.shape

        xz = self.in_proj(x) # (B, L, 2*ED)
        x, z = xz.chunk(2, dim=-1) # (B, L, ED), (B, L, ED)

        # x branch
        x = x.transpose(1, 2) # (B, ED, L)
        x = self.conv1d(x)[:, :, :L] # depthwise convolution over time, with a short filter
        x = x.transpose(1, 2) # (B, L, ED)

        x = F.silu(x);
        y = self.ssm(x, z);

        if self.config.use_cuda:
            output = self.out_proj(y) # (B, L, D)
            return output # the rest of the operations are done in the ssm function (fused with the CUDA pscan)

        # z branch
        z = F.silu(z);

        output = y * z;
        output = self.out_proj(output) # (B, L, D)

        return output
    
    def ssm(self, x, z):
        # x : (B, L, ED)

        # y : (B, L, ED)

        A = -torch.exp(self.A_log.float()) # (ED, N)
        D = self.D.float();

        deltaBC = self.x_proj(x) # (B, L, dt_rank+2*N)
        delta, B, C = torch.split(deltaBC, [self.config.dt_rank, self.config.d_state, self.config.d_state], dim=-1) # (B, L, dt_rank), (B, L, N), (B, L, N)
        delta, B, C = self._apply_layernorms(delta, B, C);
        delta = self.dt_proj.weight @ delta.transpose(1, 2) # (ED, dt_rank) @ (B, L, dt_rank) -> (B, ED, L)
        
        // ... existing code ...
        
        return y
  • 输入投影 (in_proj, L160-L160): 将输入 xd_model 维度投影到 2 * d_inner 维度,并分割成两个分支 xzx 分支用于 SSM,z 分支用于门控。
  • 1D 因果卷积 (conv1d, L162-L166): 对 x 分支应用一个短的 1D 深度可分离卷积(groups=config.d_inner),这是一种轻量级的局部特征提取方式。padding 设置为 d_conv-1 以确保输出长度与输入相同,并实现因果性(不查看未来信息)。
  • x_proj (L169-L169): 将卷积后的 x 投影到 dt_rank + 2 * d_state 维度,用于生成 deltaBC 参数。
  • dt_proj (L172-L172): 将 deltadt_rank 维度投影到 d_inner 维度。
  • 参数初始化 (L175-L208):
    • dt_proj 的偏置 (L183-L188) 初始化为 inv_dt,这是一个经过特殊计算的值,旨在控制 delta 的范围,使其在 softplus 激活后能够更好地工作。
    • A_log (L191-L193) 和 D (L195-L196) 是 SSM 的固定参数。A_log 存储了对角线状态矩阵 A 的对数,以便在训练中保持 A 的负值,确保稳定性。
  • 输出投影 (out_proj, L200-L200): 将 SSM 的输出从 d_inner 维度投影回 d_model 维度。
  • 内部层归一化 (inner_layernorms, L203-L208, L214-L219): 可选地在 deltaBC 上应用 RMS Normalization,进一步稳定这些关键参数的范围。
  • CUDA 加速 (use_cuda, L210-L213, L247-L255): 如果环境支持且配置启用,Mamba 可以调用 mamba_ssm 库中优化的 CUDA 核函数 selective_scan_fn 来加速选择性扫描操作,实现与官方实现相似的高性能。
  • 前向传播 (forward, L221-L268):
    • 将输入 x 分割为 xz
    • x 经过 1D 卷积,然后通过 F.silu 激活。
    • 核心的 选择性扫描 (self.ssm) 操作在 xz 上执行。
    • 如果使用 CUDA,ssm 的输出直接通过 out_proj
    • 如果未使用 CUDA,z 分支也经过 F.silu 激活,然后与 ssm 的输出 y 相乘(门控机制),最后通过 out_proj
  • 选择性扫描 (ssm, L270-L308): 这是 Mamba 的核心,它根据输入动态计算状态更新参数。
    • A 参数通过 self.A_log 计算得到 (L275-L275)。
    • deltaBC (L279-L279) 由 x_proj 生成,并被分割为 deltaBC
    • deltaBC 经过可选的内部层归一化 (L281-L281)。
    • delta 经过 dt_proj 线性变换 (L282-L282) 和 softplus 激活 (L295-L296)。
    • 根据 use_cuda 配置,调用 CUDA 优化的 selective_scan_cuda 或 PyTorch 实现的 selective_scan (并行) / selective_scan_seq (顺序) 版本。

4. 选择性扫描操作 (pscan.pymamba.py)

选择性扫描是 SSM 的核心,它将线性的状态空间模型(State Space Model)转换为输入依赖的动态系统,从而能够捕捉长距离依赖。

4.1. pscan (Parallel Scan)

pscan 函数(定义在 models/efficient/pscan.py 中)实现了 Blelloch 版本的并行扫描操作。它是一个 torch.autograd.Function,意味着它提供了自定义的前向传播和反向传播逻辑,以确保计算的高效性和梯度正确性。

def npo2(len):
    // ... existing code ...

def pad_npo2(X):
    // ... existing code ...

class PScan(torch.autograd.Function):
    @staticmethod
    def pscan(A, X):
        // ... existing code ...
    
    @staticmethod
    def pscan_rev(A, X):
        // ... existing code ...
    
    @staticmethod
    def forward(ctx, A_in, X_in):
        // ... existing code ...
        PScan.pscan(A, X) // Core parallel scan operation
        // ... existing code ...
    
    @staticmethod
    def backward(ctx, grad_output_in):
        // ... existing code ...
        PScan.pscan_rev(A, grad_output) // Reverse parallel scan for gradients
        // ... existing code ...
  • 目的: pscan 的目标是高效地计算一系列递归关系 H[t] = A[t] * H[t-1] + X[t] (H[0] = 0)。传统的顺序计算需要 O(L) 步,而并行扫描可以在 O(log L) 步内完成,显著加速长序列处理。
  • 实现: PScan.pscan 方法通过"向上扫描"(up sweep)和"向下扫描"(down sweep)两个阶段实现并行化。它通过重新排列和合并 AX 矩阵来并行计算中间结果,最终得到所有时间步的 H 值。
  • 反向传播: PScan.backward 方法利用 pscan_rev (L125-L125) 实现反向并行扫描,以高效地计算 AX 的梯度。

4.2. selective_scanselective_scan_seq (在 mamba.py 中)

MambaBlock 中的 ssm 方法会根据配置选择调用 selective_scan (并行版本,基于 pscan) 或 selective_scan_seq (顺序版本,用于对比或调试)。

    def selective_scan(self, x, delta, A, B, C, D):
        // ... existing code ...
        deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)

        BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)
        
        hs = pscan(deltaA, BX) // 调用pScan进行并行扫描

        y = (hs @ C.unsqueeze(-1)).squeeze(3) // (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

        y = y + D * x

        return y
    
    def selective_scan_seq(self, x, delta, A, B, C, D):
        // ... existing code ...
        deltaA = torch.exp(delta.unsqueeze(-1) * A) # (B, L, ED, N)
        deltaB = delta.unsqueeze(-1) * B.unsqueeze(2) # (B, L, ED, N)

        BX = deltaB * (x.unsqueeze(-1)) # (B, L, ED, N)

        h = torch.zeros(x.size(0), self.config.d_inner, self.config.d_state, device=deltaA.device) # (B, ED, N)
        hs = []

        for t in range(0, L): // 顺序循环
            h = deltaA[:, t] * h + BX[:, t]
            hs.append(h)
            
        hs = torch.stack(hs, dim=1) # (B, L, ED, N)

        y = (hs @ C.unsqueeze(-1)).squeeze(3) // (B, L, ED, N) @ (B, L, N, 1) -> (B, L, ED, 1)

        y = y + D * x

        return y
  • 输入相关参数x 是原始输入,deltaBC 是由输入依赖动态生成的。AD 是固定的模型参数。
  • deltaAdeltaB (L325-L326, L347-L348): 这些是核心的选择性机制。delta 参数控制了 AB 矩阵的更新,使得 SSM 能够根据输入内容动态地"选择"记住或忘记信息。
  • BX (L328-L328, L350-L350): deltaB 与输入 x 相乘,表示输入如何被添加到状态中。
  • 状态更新 (hs = pscan(deltaA, BX) / 循环 L355-L357):
    • selective_scan 通过调用 pscan (L330-L330) 来并行计算所有时间步的隐状态 h
    • selective_scan_seq 则通过一个 for 循环 (L355-L357) 逐时间步地更新隐状态 h (h = deltaA[:, t] * h + BX[:, t]),这模拟了 RNN 的行为,但计算效率较低。
  • 输出 (y, L332-L335, L360-L363): 最终的输出 y 是隐状态 hsC 矩阵的乘积,加上残差连接 D * x。这表示当前状态如何被投影到输出空间。

5. RMS 归一化 (RMSNorm)

RMSNorm 是一种替代传统的 Layer Normalization 的归一化方法,它只对输入的均方根进行归一化,而不需要减去均值。

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, use_mup: bool = False):
        super().__init__()

        self.use_mup = use_mup
        self.eps = eps

        // ... existing code ...

    def forward(self, x):
        output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

        if not self.use_mup:
            return output * self.weight
        else:
            return output
  • 原理 (L492-L492): output = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)。它计算输入 x 沿着最后一个维度的平方均值,然后取平方根的倒数,再乘以输入 x。这有效地将输入的 L2 范数归一化到 1。
  • self.weight (L485-L485, L494-L495): 可选的仿射变换(即学习到的增益参数),通常在 RMSNorm 中使用,以恢复模型的表达能力。

6. 推理阶段 (step 方法)

Mamba 模型的一大优势是其在自回归推理时具有常数时间复杂度。通过在每个 MambaBlock 中维护一个固定大小的缓存,模型可以高效地生成序列。

    def step(self, x, caches):
        // ... existing code ...
        for i, layer in enumerate(self.layers):
            x, caches[i] = layer.step(x, caches[i])
        return x, caches

class ResidualBlock(nn.Module):
    // ... existing code ...
    def step(self, x, cache):
        // ... existing code ...
        output, cache = self.mixer.step(self.norm(x), cache)
        output = output + x
        return output, cache

class MambaBlock(nn.Module):
    // ... existing code ...
    def step(self, x, cache):
        // ... existing code ...
        h, inputs = cache
        
        xz = self.in_proj(x) # (B, 2*ED)
        x, z = xz.chunk(2, dim=1) # (B, ED), (B, ED)

        # x branch
        x_cache = x.unsqueeze(2)
        x = self.conv1d(torch.cat([inputs, x_cache], dim=2))[:, :, self.config.d_conv-1] # (B, ED)

        x = F.silu(x);
        y, h = self.ssm_step(x, h);

        // ... existing code ...

        # prepare cache for next call
        inputs = torch.cat([inputs[:, :, 1:], x_cache], dim=2) # (B, ED, d_conv-1)
        cache = (h, inputs)
        
        return output, cache

    def ssm_step(self, x, h):
        // ... existing code ...
        h = deltaA * h + BX # (B, ED, N)
        y = (h @ C.unsqueeze(-1)).squeeze(2);
        y = y + D * x;
        return y, h
  • 缓存结构: 每个 MambaBlock 在推理时需要维护两个关键的缓存信息:
    • 隐状态 h: 对应于 SSM 的循环状态,形状为 (B, ED, N)
    • 最近的卷积输入 inputs: 对应于 1D 卷积的 d_conv-1 个历史输入,形状为 (B, ED, d_conv-1)。由于 d_conv 通常很小(如 4),这个缓存的大小是固定的,不会随序列长度增长。
  • MambaBlock.step() (L411-L450): 在自回归生成时,此方法接收单个时间步的输入 x 和上一时间步的缓存。它执行单步卷积和 SSM 计算,并返回当前时间步的输出和更新后的缓存,供下一时间步使用。inputs 缓存通过移除最旧的输入并添加当前输入来更新。
  • ssm_step() (L452-L476): 这是 MambaBlock 内部用于单步推理的 SSM 计算,它接收当前时间步的 x 和上一时间步的 h,并更新 h

通过上述详细的模块设计和实现,Mamba 模型能够高效地处理长序列,并在保持线性复杂度的同时,有效地捕捉上下文信息,为超越 Transformer 的序列建模提供了新的方向。