# 推测解码 background

推测解码的核心思想可以概括为 “小模型快速起草,大模型并行验证” (Draft-and-Verify),并且通过 改进的拒绝采样(Modified Rejection Sampling) 保证输出结果的数学等价性(Lossless)。

具体而言,包含以下几个步骤:

  • 推测(Drafting): 引入一个参数量远小于目标大模型(Target Model)的草稿模型(Draft Model)。草稿模型以自回归的方式快速生成 K 个候选 Token。
  • 验证(Verifying): 将这 K 个候选 Token 连同之前的上下文,作为一个序列(Sequence)一次性输入给目标大模型。大模型通过一次前向传播(Forward Pass),并行计算出这 K 个位置的真实概率分布。
  • 接受与拒绝(Accept/Reject): 比较草稿模型和大模型在每个位置的概率分布。如果草稿模型的预测被大模型认可(基于特定的概率阈值进行采样),则接受该 Token;一旦某个位置的 Token 被拒绝,则丢弃该位置及之后的所有候选 Token,并使用大模型在该位置的真实分布重新采样出一个正确的 Token。
  • 理论保证: Chen et al. 证明了,通过这种特定的拒绝采样算法,推测解码最终生成的 Token 序列的概率分布,与完全使用大模型逐字生成的概率分布在数学上是严格一致的(完全无损)

模型规模对比:

  • 目标大模型(Target Model): 通常在 7B 到 70B+ 级别(例如 LLaMA-2-70B, Chinchilla 70B 等)。
  • 草稿模型(Draft Model): 通常是大模型规模的 1/10 到 1/50,参数量在 100M 到 1B 级别(例如 160M, 1.3B 的同系列微调模型)。

# Roofline 模型与访存墙

GPU 的实际性能由两个因素共同决定:

  • 峰值算力 π\pi(FLOPs/s),例如 A100 FP16 约为 312 TFLOPs/s
  • 显存带宽 β\beta(Bytes/s),例如 A100 HBM2e 约为 2 TB/s

定义算术强度 I=FLOPsBytesI = \frac{\text{FLOPs}}{\text{Bytes}},则实际可达性能为:

P=min(π,I×β)P = \min(\pi, \, I \times \beta)

A100 的 "拐点" 算术强度为 π/β156\pi / \beta \approx 156 FLOPs/Byte。低于此值即为访存受限,高于此值即为计算受限。

# Medusa

# 问题

LLM 推理慢的根本原因在于:自回归解码是 **memory-bandwidth-bound(内存带宽受限)** 的 —— 每生成一个 token 都要把整个模型参数从 HBM 搬到加速器缓存里,但每步只产出一个 token,计算单元严重闲置。

传统的 **speculative decoding(投机解码)** 通过引入一个小的 draft model 来 "猜" 多个 token,然后让大模型并行验证。但这有两个痛点:(1) 需要单独预训练一个对齐的小模型(成本高、有分布偏移);(2) 在分布式系统中维护两个模型很麻烦。

# 创新点

# MEDUSA Heads —— 多个解码头并行预测

不再使用独立的 draft model,而是在原模型最后一层 hidden state (最后一层 Decoder 输出)上接多个轻量级解码头,每个头负责预测后续不同位置的 token。这样既避免了分布偏移问题,又便于集成到现有系统。

Vicuna = LLaMA(预训练模型) + ShareGPT 对话数据监督微调(SFT)

种子任务:为了解决训练指令模型需要大量 “(指令,输入,输出)” 三元组数据的问题

  1. 用 175 条人工写的高质量指令(种子数据),
  2. 喂给一个比较强的模型(当年是 GPT-3),
  3. 让 GPT-3 模仿这些例子,生成新的任务
  4. 再让 GPT-3 自己回答

最后获得 52000 条指令数据

KL 散度:

KL(PQ)KL(P∥Q) 的大小表示 ——“如果真实分布是 PP,我却错误地用了 Q 来表示,那么平均每个结果我会产生多大的对数偏差(按 P 加权)?”

KL 散度 KL(PQ)KL(P∥Q) = 按 PP 的权重,对每个 xx 计算 log(P(x)/Q(x))log(P(x)/Q(x)) 的期望。

它衡量的是:在真实分布 P 认为重要的结果上,Q 相对于 P 的平均对数偏差

  • 数据

    当 target model 有原始 SFT 训练数据集(ground truth)时,直接用;当原始训练数据不可用或模型经过 RLHF 时:用 self-distillation,让 MEDUSA head 学 target model

  • 损失函数

    对 MEDUSA-1(冻结主干):

    LMEDUSA-1=k=1Kλklogpt(k)(yt+k+1)\mathcal{L}_{\text{MEDUSA-1}} = \sum_{k=1}^{K} -\lambda_k \log p_t^{(k)}\left(y_{t+k+1}\right)

    每个 head 对其负责位置的 ground truth 做交叉熵,再用权重 λk\lambda_k 加权。由于 k 越大预测越不确定,损失也越大,避免模型训练中过于关注靠后的 medusa head,所以加权重 λk\lambda_k 来平衡不同 heads 的损失。实践中将 λk\lambda_k 设为常数(如 0.8)的 k 次方。

    对 MEDUSA-2(联合训练,LoRA 微调主干模型):

    LMEDUSA-2=LLM+λ0LMEDUSA-1\mathcal{L}_{\text{MEDUSA-2}} = \mathcal{L}_{\text{LM}} + \lambda_0 \mathcal{L}_{\text{MEDUSA-1}}

    即额外加上主干模型的下一个 token 预测损失,避免破坏原始能力。
    注意!如果用 sft 数据,主干部分的损失用交叉熵;如果用 self-distillation,主干部分的损失为 KL 散度(让微调后的模型对齐原模型分布;若 CE 把 Teacher 的采样结果当真理,会导致分布坍缩)

  • MEDUSA Head 的具体结构 —— 单层带残差连接的前馈网络

    pt(k)=softmax(W2(k)(SiLU(W1(k)ht)+ht))p_t^{(k)} = \mathrm{softmax}\left( W_2^{(k)} \cdot \left( \mathrm{SiLU}\left( W_1^{(k)} \cdot h_t \right) + h_t \right) \right)

    • 输入:htRdh_t \in \mathbb{R}^d(原模型最后一层 hidden state, d 是 hidden size)

    • W1(k)Rd×dW_1^{(k)} \in \mathbb{R}^{d \times d}:第一个线性层(保持维度)

    • 激活函数:SiLU,跟随 Llama 模型的设计

    • 残差连接:把 hth_t 直接加回到 SiLU 输出上

    • W2(k)Rd×VW_2^{(k)} \in \mathbb{R}^{d \times V}:第二个线性层投影到词表大小 V(其实就是一个 LM head)

# Tree Attention —— 树形注意力并行验证多候选

每个 head 都输出 top-k 候选,通过笛卡尔积形成多条候选序列,用一个特殊的 attention mask 让每个 token 只能看到自己所在分支的祖先节点,从而一次前向就并行验证多个候选序列,不需要扩大 batch size。

  • 构造树之后,我们先用校准集(与测试集同分布,但是不相同的一小批数据),先跑一遍统计每个 head 在不同 rank 位置上的 "准确率"(即排第 k 的候选 token 是真实 token 的概率),然后枚举所有可能的树节点,按 "路径上各 节点 经验准确率的乘积" 排序,取前 64 个节点构造稀疏树。

稀疏树的叶子节点可以在任意深度,高概率路径深、低概率路径早早被截。

# Typical Acceptance —— 典型性接受方案

替代 speculative decoding 中的 rejection sampling。使用原模型的预测概率作为衡量标准,并基于预测分布建立一个阈值来判断接受与否。简单说就是只要候选 token 在原模型分布下 "不算太离谱" 就接受,相比严格的拒绝采样能大幅提升接受率,同时保持生成质量。

** 核心想法:** 只要候选 token xx 在大模型的预测分布 p(context)p(\cdot|\text{context}) 里 "不算太离谱",就接受它。

"不算太离谱" 用两个阈值刻画。设大模型在该位置的分布为 pp,候选 token 是 xx,则接受条件是:

p(x)>min(ϵ, δexp(H(p)))p(x) > \min\left(\epsilon,\ \delta \cdot \exp(-H(p))\right)

其中:

  • ϵ\epsilon 是一个绝对阈值(比如 0.09)
  • H(p)H(p) 是分布 pp 的熵
  • δexp(H(p))\delta \cdot \exp(-H(p)) 是一个相对于分布尖锐度的阈值

直觉:

  • 如果大模型很确定(熵小,exp(H)\exp(-H) 大),那阈值就高,要求候选必须是高概率 token。
  • 如果大模型很不确定(熵大,exp(H)\exp(-H) 小),那阈值就放低,因为反正怎么选都行。

todo:eagle,n-gram,mtp