pydata: Huiming's learning notes

Keep Looking, Don't Settle

DeepSeek-V3.2 Lightning Indexer

Deepseek Lighting Indexer plot

背景

在之前的文章 DeepSeek V3 Multi-head Latent Attention (MLA) 中,我们详细介绍了 DeepSeek-V3 的 MLA 机制如何通过低秩压缩减少 KV Cache 的内存占用。MLA 解决了推理时的内存带宽问题,但 attention 的计算复杂度仍然是 \(O(L^2)\),随着序列长度 \(L\) 增长(如 128K tokens),计算量依然巨大。

DeepSeek-V3.2 引入了 DeepSeek Sparse Attention (DSA),其核心思想是:不需要让每个 query token 关注所有 preceding tokens,只需关注最相关的 top-k 个即可。这就需要一个快速评估机制来决定哪些 token 值得关注——这就是 Lightning Indexer 的作用。

本文基于 DeepSeek-V3.2 论文和开源实现,详细解析 Lightning Indexer 的工作原理。


1. DSA 的整体设计思路

DSA 由两个组件构成:

  1. Lightning Indexer:以极低计算代价为每个 query token 打出一个"相关性分数" \(I_{t,s}\),决定应该关注哪些 preceding tokens
  2. Fine-grained Token Selection:根据 indexer 的分数,只选择 top-k 个 token 做完整的 MLA attention

这样,主 attention 的复杂度从 \(O(L^2)\) 降低到 \(O(L \cdot k)\),其中 \(k = 2048 \ll L = 128K\)

虽然 Lightning Indexer 本身仍是 \(O(L^2)\),但由于它使用 FP8 计算、MQA 设计、ReLU 激活,实际计算量远小于完整的 MLA attention。


2. 符号定义

首先定义本文使用的维度符号,对应 config_671B_v3.2.json 中的配置:

$$ \begin{aligned} & d = 7168 \quad \text{(模型隐藏维度, dim)} \\ & n_h = 128 \quad \text{(MLA attention head数, n\_heads)} \\ & d_c = 512 \quad \text{(KV LoRA rank, kv\_lora\_rank)} \\ & d_c' = 1536 \quad \text{(Query LoRA rank, q\_lora\_rank)} \\ & d_h = 128 \quad \text{(qk\_nope\_head\_dim)} \\ & d_h^R = 64 \quad \text{(qk\_rope\_head\_dim)} \\ & H_I = 64 \quad \text{(indexer head数, index\_n\_heads)} \\ & d_I = 128 \quad \text{(indexer head维度, index\_head\_dim)} \\ & k = 2048 \quad \text{(top-k选择数, index\_topk)} \end{aligned} $$

3. Lightning Indexer 核心公式

3.1 Index Score 计算 (论文 Equation 1)

对于 query token 位置 \(t\) 和 preceding token 位置 \(s\),indexer 计算一个标量分数:

$$I_{t,s} = \sum_{j=1}^{H_I} w^I_{t,j} \cdot \text{ReLU}\left(\mathbf{q}^I_{t,j} \cdot \mathbf{k}^I_s\right) \tag{1}$$

其中:

  • \(H_I = 64\):indexer 的 head 数量
  • \(\mathbf{q}^I_{t,j} \in \mathbb{R}^{d_I}\):位置 \(t\) 在第 \(j\) 个 indexer head 的 query 向量,\(d_I = 128\)
  • \(\mathbf{k}^I_s \in \mathbb{R}^{d_I}\):位置 \(s\) 的 indexer key 向量,所有 head 共享(Multi-Query 设计)
  • \(w^I_{t,j} \in \mathbb{R}\):位置 \(t\) 在第 \(j\) 个 head 的标量权重

直觉理解:这是一个加权多头线性注意力。每个 head 计算 query-key 的点积,用 ReLU 截断负值(只保留正相关),然后通过可学习的权重 \(w^I_{t,j}\) 汇总所有 head 的结果。

3.2 Top-k 选择 (论文 Equation 2)

根据分数选出最相关的 \(k\) 个 token,只对这些 token 做完整 MLA attention:

$$\mathbf{u}_t = \text{Attn}\left(\mathbf{h}_t, \left\{\mathbf{c}_s \mid I_{t,s} \in \text{Top-k}(I_{t,:})\right\}\right) \tag{2}$$

其中 \(\mathbf{c}_s\) 是 MLA 中 token \(s\) 的 latent KV 表示 \(\mathbf{c}_s^{KV} \in \mathbb{R}^{512}\)


4. 各矩阵的详细变换过程

4.1 Indexer Query 的计算

Lightning Indexer 的 query 复用了 MLA 的 query 压缩中间表示 \(\mathbf{c}_t^Q\)(即 qr),避免重复计算:

Step 1: MLA query 压缩(与 MLA 共享)

$$\mathbf{c}_t^Q = \text{RMSNorm}(\mathbf{W}^{DQ} \mathbf{h}_t) \in \mathbb{R}^{d_c'}$$

其中 \(\mathbf{W}^{DQ} \in \mathbb{R}^{d_c' \times d} = \mathbb{R}^{1536 \times 7168}\),这一步与 MLA 共享。

Step 2: Indexer query 投影

$$\mathbf{q}^I_t = \mathbf{W}^{I,UQ} \mathbf{c}_t^Q \in \mathbb{R}^{H_I \cdot d_I}$$

其中 \(\mathbf{W}^{I,UQ} \in \mathbb{R}^{(H_I \cdot d_I) \times d_c'} = \mathbb{R}^{8192 \times 1536}\)

Reshape 后得到每个 head 的 query 向量:

$$\mathbf{q}^I_t \to \{\mathbf{q}^I_{t,j}\}_{j=1}^{64}, \quad \mathbf{q}^I_{t,j} \in \mathbb{R}^{128}$$

Step 3: 部分施加 RoPE

对每个 head 的 128 维向量,从前面取出 \(d_h^R = 64\) 维施加 RoPE:

$$\mathbf{q}^I_{t,j} = [\underbrace{\mathbf{q}^{I,pe}_{t,j}}_{64d} \;|\; \underbrace{\mathbf{q}^{I,nope}_{t,j}}_{64d}]$$
$$\mathbf{q}^I_{t,j} \leftarrow [\text{RoPE}(\mathbf{q}^{I,pe}_{t,j}) \;;\; \mathbf{q}^{I,nope}_{t,j}] \in \mathbb{R}^{128}$$

Step 4: Hadamard Rotation + FP8 量化

$$\mathbf{q}^I_{t,j} \leftarrow \text{Quantize}_{FP8}(\mathbf{H} \cdot \mathbf{q}^I_{t,j})$$

Hadamard Rotation 将 pe 和 nope 维度混合,使数值分布更均匀,有利于 FP8 量化精度。

矩阵维度汇总

步骤 输入 变换矩阵 输出
压缩(共享) \(\mathbf{h}_t \in \mathbb{R}^{7168}\) \(\mathbf{W}^{DQ} \in \mathbb{R}^{1536 \times 7168}\) \(\mathbf{c}_t^Q \in \mathbb{R}^{1536}\)
投影 \(\mathbf{c}_t^Q \in \mathbb{R}^{1536}\) \(\mathbf{W}^{I,UQ} \in \mathbb{R}^{8192 \times 1536}\) \(\mathbf{q}^I_t \in \mathbb{R}^{8192}\)
reshape \(\mathbb{R}^{8192}\) \(\mathbb{R}^{64 \times 128}\)
RoPE 每 head 前 64d 旋转矩阵 每 head 前 64d
Hadamard \(\mathbb{R}^{64 \times 128}\) \(\mathbf{H} \in \mathbb{R}^{128 \times 128}\) \(\mathbb{R}^{64 \times 128}\)
FP8 量化 float → fp8 \(\mathbb{R}^{64 \times 128}\) (FP8)

4.2 Indexer Key 的计算

Key 直接从输入隐藏状态投影,所有 64 个 head 共享同一个 key(Multi-Query Attention 风格):

Step 1: 线性投影 + LayerNorm

$$\mathbf{k}^I_s = \text{LayerNorm}(\mathbf{W}^{I,K} \mathbf{h}_s) \in \mathbb{R}^{d_I}$$

其中 \(\mathbf{W}^{I,K} \in \mathbb{R}^{d_I \times d} = \mathbb{R}^{128 \times 7168}\)

Step 2: 部分施加 RoPE

与 query 相同,前 64 维施加 RoPE:

$$\mathbf{k}^I_s = [\underbrace{\mathbf{k}^{I,pe}_{s}}_{64d} \;|\; \underbrace{\mathbf{k}^{I,nope}_{s}}_{64d}]$$
$$\mathbf{k}^I_s \leftarrow [\text{RoPE}(\mathbf{k}^{I,pe}_{s}) \;;\; \mathbf{k}^{I,nope}_{s}] \in \mathbb{R}^{128}$$

Step 3: Hadamard Rotation + FP8 量化

$$\mathbf{k}^I_s \leftarrow \text{Quantize}_{FP8}(\mathbf{H} \cdot \mathbf{k}^I_s)$$

矩阵维度汇总

步骤 输入 变换矩阵 输出
投影 \(\mathbf{h}_s \in \mathbb{R}^{7168}\) \(\mathbf{W}^{I,K} \in \mathbb{R}^{128 \times 7168}\) \(\mathbf{k}^I_s \in \mathbb{R}^{128}\)
LayerNorm \(\mathbb{R}^{128}\) \(\mathbb{R}^{128}\)
RoPE 前 64d 旋转矩阵 前 64d
Hadamard \(\mathbb{R}^{128}\) \(\mathbf{H} \in \mathbb{R}^{128 \times 128}\) \(\mathbb{R}^{128}\)
FP8 量化 float → fp8 \(\mathbb{R}^{128}\) (FP8)

关键设计选择:Key 不分 head,所有 64 个 query head 共享同一个 key 向量。这意味着 indexer 的 KV cache 只需每个 token 存 128 维(FP8)= 128 bytes,远小于 MLA 的 KV cache(\(c_t^{KV}\) 512 维 + \(k_t^R\) 64 维)。

4.3 Head 权重的计算

每个 head 有一个可学习的标量权重 \(w^I_{t,j}\),用于控制各 head 对最终分数的贡献:

$$\mathbf{w}^I_t = \frac{1}{\sqrt{H_I}} \cdot \mathbf{W}^{I,w} \mathbf{h}_t \in \mathbb{R}^{H_I}$$

其中 \(\mathbf{W}^{I,w} \in \mathbb{R}^{H_I \times d} = \mathbb{R}^{64 \times 7168}\)

Note: 这里 \(\mathbf{w}^I_t\) 直接从 \(\mathbf{h}_t\) 投影得到(float32 精度),而不是从压缩表示 \(\mathbf{c}_t^Q\) 得到。这说明权重需要保留更完整的隐藏状态信息。

4.4 完整的 Score 计算流程

将上述所有部分组合,计算 \(I_{t,s}\)

$$I_{t,s} = \sum_{j=1}^{64} w^I_{t,j} \cdot \text{ReLU}\left(\mathbf{q}^{I}_{t,j} \cdot \mathbf{k}^{I}_s\right)$$

在实际 CUDA kernel 实现中,由于使用了 FP8 block quantization,计算被重组为:

# weights 与 q 的 scale 融合
weights = (W_weights @ h_t) / sqrt(H_I)       # (B, L, 64)
weights = weights.unsqueeze(-1) * q_scale * softmax_scale
# shape: (B, L, 64, head_dim // block_size)

# fp8_index kernel 一次性完成: dequant + dot product + ReLU + weighted sum
index_score = fp8_index(q_fp8, weights, k_cache, k_scale_cache)
# shape: (B, L, L)

5. 与 MLA 中 RoPE 处理方式的对比

这是 Lightning Indexer 与 MLA 主 Attention 的一个本质结构差异,值得单独讨论。

5.1 MLA 主 Attention 的 RoPE

在 MLA 中(详见上一篇博客),RoPE 通过解耦(Decoupled RoPE)的方式引入:

  • Key 的位置信息 \(\mathbf{k}_t^R\) 是从 wkv_a 投影中独立分支得到的 64 维向量
  • wkv_a 的输出被 split 为 [c_t^KV (512d) | k_t^R (64d)]\(\mathbf{k}_t^R\)单独投影
  • 最终 key = concat[\(\mathbf{k}_{t,i}^C\) (128d, 从 \(\mathbf{c}_t^{KV}\) 展开) ; \(\mathbf{k}_t^R\) (64d, RoPE)] = 192d
  • 最终 query = concat[\(\mathbf{q}_{t,i}^C\) (128d) ; \(\mathbf{q}_{t,i}^R\) (64d, RoPE)] = 192d
  • 顺序:nope 在前,pe 在后

数学表达:

$$\mathbf{k}_{t,i}^{MLA} = [\mathbf{k}_{t,i}^C \;;\; \text{RoPE}(\mathbf{k}_t^R)] \in \mathbb{R}^{192}$$
$$\mathbf{q}_{t,i}^{MLA} = [\mathbf{q}_{t,i}^C \;;\; \text{RoPE}(\mathbf{q}_{t,i}^R)] \in \mathbb{R}^{192}$$

其中 \(\mathbf{k}_t^R = \mathbf{W}^{KR} \mathbf{h}_t \in \mathbb{R}^{64}\) 是独立投影。

设计原因:MLA 中 \(\mathbf{c}_t^{KV}\) 需要被缓存用于推理,如果对 \(\mathbf{c}_t^{KV}\) 展开后的 key 施加 RoPE,则无法缓存压缩表示(因为 RoPE 依赖位置)。因此必须将位置信息解耦为独立分支 \(\mathbf{k}_t^R\),在推理时单独缓存 \(\mathbf{k}_t^R\)(64 维)。

5.2 Lightning Indexer 的 RoPE

Indexer 采用了更简单的方式:

  • Key 通过单一投影得到 128 维向量
  • 在同一向量内部 split:前 64 维施加 RoPE,后 64 维不动
  • 最终 key = concat[\(\mathbf{k}^{I,pe}\) (64d, RoPE) ; \(\mathbf{k}^{I,nope}\) (64d)] = 128d
  • 顺序:pe 在前,nope 在后(与 MLA 相反!)
  • 最后还要经过 Hadamard Rotation 混合所有维度

数学表达:

$$\mathbf{k}^I_s = \mathbf{H} \cdot [\text{RoPE}(\mathbf{k}^{I,pe}_s) \;;\; \mathbf{k}^{I,nope}_s] \in \mathbb{R}^{128}$$

5.3 为什么可以不同?

设计考量 MLA Lightning Indexer
目标 精确的 attention 计算 快速近似排序即可
需要缓存什么 \(\mathbf{c}_t^{KV}\)(位置无关)+ \(\mathbf{k}_t^R\)(位置相关) 整个 \(\mathbf{k}^I_s\)(已含位置)
为什么可以 必须解耦,否则无法压缩缓存 不需要压缩,直接缓存完整 128d key(FP8 只需 128 bytes)
Hadamard 的作用 不需要 将 pe/nope 混合,使 FP8 量化误差分布更均匀

核心区别在于:MLA 需要在推理时从缓存的 \(\mathbf{c}_t^{KV}\) 恢复 key,因此位置信息必须解耦;而 Indexer 直接缓存完整的 key(FP8),不需要解耦。


6. 完整数据流图

h_t ∈ R^7168 (输入隐藏状态)
    │
    ├──→ W^DQ (7168→1536) → RMSNorm → c_t^Q ∈ R^1536  [MLA与Indexer共享]
    │       │
    │       ├──→ [MLA主attention]
    │       │       W^UQ (1536→128×192) → q_MLA ∈ R^(128, 192)
    │       │
    │       └──→ [Lightning Indexer]
    │               W^{I,UQ} (1536→64×128) → reshape → R^(64, 128)
    │               → split[前64 | 后64] → RoPE(前64) → concat
    │               → Hadamard → FP8 → q_I ∈ R^(64, 128)
    │
    ├──→ [Lightning Indexer Key]
    │       W^{I,K} (7168→128) → LayerNorm
    │       → split[前64 | 后64] → RoPE(前64) → concat
    │       → Hadamard → FP8 → k_I ∈ R^128
    │
    └──→ [Lightning Indexer Weight]
            W^{I,w} (7168→64) / √64 → w_I ∈ R^64  [float32]

                            ↓
        I_{t,s} = Σ_j w_I[j] · ReLU(q_I[t,j] · k_I[s])
                            ↓
                    Top-2048 selection
                            ↓
            只对选中的 2048 个 token 做完整 MLA Attention

7. Prefill 和 Decode 时的行为

7.1 Prefill(处理整个输入序列)

在 prefill 阶段,所有 token 同时处理:

# 1. 计算所有 token 的 indexer score (B, L, L)
topk_indices = indexer(x, qr, start_pos, freqs_cis, mask)

# 2. 构造 sparse mask
index_mask = torch.full((B, L, L), float("-inf"))
index_mask.scatter_(-1, topk_indices, 0)  # 选中位置设为0
index_mask += causal_mask                  # 叠加因果mask

# 3. 主 MLA attention 计算 (MHA mode)
q = torch.cat([q_nope, q_pe], dim=-1)     # (B, L, 128, 192)
k = torch.cat([k_nope, k_pe], dim=-1)     # (B, L, 128, 192)
scores = einsum("bshd,bthd->bsht", q, k) * scale
scores += index_mask.unsqueeze(2)          # 未选中位置被mask为-inf
scores = softmax(scores, dim=-1)
output = einsum("bsht,bthd->bshd", scores, v)

7.2 Decode(逐 token 生成)

在 decode 阶段,每次只处理 1 个新 token,但需要和之前所有 token 的缓存做计算:

# 1. 新 token 的 indexer key 入缓存
k_cache[:bsz, end_pos] = k_fp8  # 追加到 indexer key cache

# 2. 计算新 token 与所有历史 token 的 indexer score
topk_indices = indexer(x, qr, start_pos, freqs_cis, mask=None)
# shape: (B, 1, 2048) — 从 end_pos 个候选中选出 2048 个

# 3. 只对选中的 2048 个位置做 MLA attention (MQA mode)
index_mask = torch.full((B, 1, end_pos), float("-inf"))
index_mask.scatter_(-1, topk_indices, 0)
scores += index_mask.unsqueeze(2)
scores = softmax(scores, dim=-1)
output = einsum("bsht,btc->bshc", scores, kv_cache[:bsz, :end_pos])

8. 训练过程

8.1 Dense Warm-up 阶段 (论文 Equation 3)

首先冻结主模型所有参数,只训练 Lightning Indexer。目标是让 indexer 的分数分布与真实 attention 分布对齐:

$$\mathcal{L}_I = \sum_t D_{KL}\left(p_{t,:} \;\|\; \text{Softmax}(I_{t,:})\right) \tag{3}$$

其中 \(p_{t,:} \in \mathbb{R}^t\) 是目标分布:将所有 128 个 MLA attention head 的 attention score 求和,再沿序列维度 L1-归一化:

$$p_{t,s} = \frac{\sum_{i=1}^{n_h} \alpha_{t,s,i}}{\sum_{s'} \sum_{i=1}^{n_h} \alpha_{t,s',i}}$$

其中 \(\alpha_{t,s,i}\) 是主 attention 第 \(i\) 个 head 在位置 \((t,s)\) 的 attention score。

训练参数

  • 学习率:\(10^{-3}\)
  • 训练步数:1000 步
  • 每步数据:16 sequences × 128K tokens
  • 总训练数据量:2.1B tokens

8.2 Sparse Training 阶段 (论文 Equation 4)

解冻主模型参数,同时训练主模型和 indexer,但使用梯度分离(detach)

  • Indexer 只接收 \(\mathcal{L}_I\) 的梯度
  • 主模型只接收语言建模 loss 的梯度

此阶段的 indexer loss 只在已选中的 token 集合 \(\mathcal{S}_t\) 上计算:

$$\mathcal{L}_I = \sum_t D_{KL}\left(p_{t,\mathcal{S}_t} \;\|\; \text{Softmax}(I_{t,\mathcal{S}_t})\right) \tag{4}$$

其中 \(\mathcal{S}_t = \{s \mid I_{t,s} \in \text{Top-k}(I_{t,:})\}\)

训练参数

  • 学习率:\(7.3 \times 10^{-6}\)(比 warm-up 小约 137 倍)
  • 训练步数:15000 步
  • 每步数据:480 sequences × 128K tokens
  • 总训练数据量:943.7B tokens
  • Top-k = 2048(每个 query 只选 2048 个 KV token)

Note:将 indexer 的梯度从主模型的计算图中 detach 是一个重要的设计选择。这避免了 indexer 的学习信号干扰主模型的语言建模能力,也避免了 top-k 选择带来的不可导问题。


9. 效率分析

9.1 计算量对比

组件 每 query token 的计算量 说明
MLA attention (dense) \(128 \times 192 \times L \times 2\) FLOPs 128 heads, 192d per head, \(L\) tokens
Lightning Indexer \(64 \times 128 \times L \times 2\) FLOPs (FP8) 64 heads, 128d per head, \(L\) tokens
MLA attention (sparse) \(128 \times 192 \times 2048 \times 2\) FLOPs 只对 2048 tokens 做完整 attention

Indexer 的计算量约为 dense MLA 的:

$$\frac{64 \times 128}{128 \times 192} = \frac{8192}{24576} = \frac{1}{3}$$

考虑 FP8 vs BF16 的吞吐差异(FP8 约 2x),实际加速约 6x

9.2 内存/KV Cache 对比

组件 每 token KV Cache 精度
MLA 主 attention \(\mathbf{c}_t^{KV}\) (512d) + \(\mathbf{k}_t^R\) (64d) = 576 values BF16 (1152 bytes)
Lightning Indexer \(\mathbf{k}^I_s\) (128d) FP8 (128 bytes)

Indexer 的 KV cache 只有主 attention 的 11%

9.3 端到端效果

论文 Figure 3 显示,当序列长度 > 32K 时,DSA 带来的加速显著。对于 128K 上下文,token 成本约降低 70-80%


10. 设计选择总结

设计选择 原因
ReLU 而非 Softmax Softmax 需要遍历所有位置计算归一化分母;ReLU 逐元素操作,可完全并行,且硬件友好
Multi-Query (共享 Key) 减少 key 的 KV cache(128 bytes/token vs 128d×64heads),且 kernel 实现更高效
FP8 量化 H800 GPU 的 FP8 Tensor Core 吞吐是 BF16 的 2x
Hadamard Rotation 消除量化前数值分布的不均匀性(outlier channels),提升 FP8 精度
复用 \(\mathbf{c}_t^Q\) 避免重复计算 query 压缩,indexer 只增加一个轻量投影
Head 权重 \(w^I_{t,j}\) 不同 head 的重要性可以根据输入动态调整,增加表达力
梯度分离 避免 indexer 训练信号污染主模型的语言建模能力

11. 所有投影矩阵汇总

矩阵 维度 参数量 用途
\(\mathbf{W}^{DQ}\) (共享) \(1536 \times 7168\) 11.0M Query 压缩
\(\mathbf{W}^{I,UQ}\) \(8192 \times 1536\) 12.6M Indexer query 展开
\(\mathbf{W}^{I,K}\) \(128 \times 7168\) 0.9M Indexer key 投影
\(\mathbf{W}^{I,w}\) \(64 \times 7168\) 0.5M Head 权重投影
Indexer 独有总计 14.0M 每层约 14M 参数

对于 61 层模型,Lightning Indexer 总共增加约 \(14M \times 61 \approx 854M\) 参数,不到总模型 671B 的 0.13%


12. 总结

Lightning Indexer 是一个精心设计的"快速预筛选器":

  1. 输入:复用 MLA 的 query 压缩表示 + 独立的 key 投影
  2. 计算:Multi-Query、FP8、ReLU、Hadamard Rotation 四重优化
  3. 输出:每个 query token 选出 2048 个最相关的 preceding token
  4. 效果:以不到 1/6 的计算代价,实现了与 dense attention 相当的效果

其核心设计哲学是:用一个极其廉价但足够准确的近似分数,替代大部分不必要的精确 attention 计算。这使得 DeepSeek-V3.2 能在 128K 上下文长度下,显著降低推理成本。


参考