# 推测解码 background
推测解码的核心思想可以概括为 “小模型快速起草,大模型并行验证” (Draft-and-Verify),并且通过 改进的拒绝采样(Modified Rejection Sampling) 保证输出结果的数学等价性(Lossless)。
具体而言,包含以下几个步骤:
- 推测(Drafting): 引入一个参数量远小于目标大模型(Target Model)的草稿模型(Draft Model)。草稿模型以自回归的方式快速生成 K 个候选 Token。
- 验证(Verifying): 将这 K 个候选 Token 连同之前的上下文,作为一个序列(Sequence)一次性输入给目标大模型。大模型通过一次前向传播(Forward Pass),并行计算出这 K 个位置的真实概率分布。
- 接受与拒绝(Accept/Reject): 比较草稿模型和大模型在每个位置的概率分布。如果草稿模型的预测被大模型认可(基于特定的概率阈值进行采样),则接受该 Token;一旦某个位置的 Token 被拒绝,则丢弃该位置及之后的所有候选 Token,并使用大模型在该位置的真实分布重新采样出一个正确的 Token。
- 理论保证: Chen et al. 证明了,通过这种特定的拒绝采样算法,推测解码最终生成的 Token 序列的概率分布,与完全使用大模型逐字生成的概率分布在数学上是严格一致的(完全无损)。
模型规模对比:
- 目标大模型(Target Model): 通常在 7B 到 70B+ 级别(例如 LLaMA-2-70B, Chinchilla 70B 等)。
- 草稿模型(Draft Model): 通常是大模型规模的 1/10 到 1/50,参数量在 100M 到 1B 级别(例如 160M, 1.3B 的同系列微调模型)。
# Roofline 模型与访存墙
GPU 的实际性能由两个因素共同决定:
- 峰值算力 (FLOPs/s),例如 A100 FP16 约为 312 TFLOPs/s
- 显存带宽 (Bytes/s),例如 A100 HBM2e 约为 2 TB/s
定义算术强度 ,则实际可达性能为:
A100 的 "拐点" 算术强度为 FLOPs/Byte。低于此值即为访存受限,高于此值即为计算受限。
# Medusa
# 问题
LLM 推理慢的根本原因在于:自回归解码是 **memory-bandwidth-bound(内存带宽受限)** 的 —— 每生成一个 token 都要把整个模型参数从 HBM 搬到加速器缓存里,但每步只产出一个 token,计算单元严重闲置。
传统的 **speculative decoding(投机解码)** 通过引入一个小的 draft model 来 "猜" 多个 token,然后让大模型并行验证。但这有两个痛点:(1) 需要单独预训练一个对齐的小模型(成本高、有分布偏移);(2) 在分布式系统中维护两个模型很麻烦。
# 创新点
# MEDUSA Heads —— 多个解码头并行预测
不再使用独立的 draft model,而是在原模型最后一层 hidden state (最后一层 Decoder 输出)上接多个轻量级解码头,每个头负责预测后续不同位置的 token。这样既避免了分布偏移问题,又便于集成到现有系统。
Vicuna = LLaMA(预训练模型) + ShareGPT 对话数据监督微调(SFT)
种子任务:为了解决训练指令模型需要大量 “(指令,输入,输出)” 三元组数据的问题
- 用 175 条人工写的高质量指令(种子数据),
- 喂给一个比较强的模型(当年是 GPT-3),
- 让 GPT-3 模仿这些例子,生成新的任务
- 再让 GPT-3 自己回答
最后获得 52000 条指令数据
KL 散度:
的大小表示 ——“如果真实分布是 ,我却错误地用了 Q 来表示,那么平均每个结果我会产生多大的对数偏差(按 P 加权)?”
KL 散度 = 按 的权重,对每个 计算 的期望。
它衡量的是:在真实分布 P 认为重要的结果上,Q 相对于 P 的平均对数偏差。
-
数据
当 target model 有原始 SFT 训练数据集(ground truth)时,直接用;当原始训练数据不可用或模型经过 RLHF 时:用 self-distillation,让 MEDUSA head 学 target model
-
损失函数
对 MEDUSA-1(冻结主干):
每个 head 对其负责位置的 ground truth 做交叉熵,再用权重 加权。由于 k 越大预测越不确定,损失也越大,避免模型训练中过于关注靠后的 medusa head,所以加权重 来平衡不同 heads 的损失。实践中将 设为常数(如 0.8)的 k 次方。
对 MEDUSA-2(联合训练,LoRA 微调主干模型):
即额外加上主干模型的下一个 token 预测损失,避免破坏原始能力。
注意!如果用 sft 数据,主干部分的损失用交叉熵;如果用 self-distillation,主干部分的损失为 KL 散度(让微调后的模型对齐原模型分布;若 CE 把 Teacher 的采样结果当真理,会导致分布坍缩) -
MEDUSA Head 的具体结构 —— 单层带残差连接的前馈网络
-
输入:(原模型最后一层 hidden state, d 是 hidden size)
-
:第一个线性层(保持维度)
-
激活函数:SiLU,跟随 Llama 模型的设计
-
残差连接:把 直接加回到 SiLU 输出上
-
:第二个线性层投影到词表大小 V(其实就是一个 LM head)
-
# Tree Attention —— 树形注意力并行验证多候选
每个 head 都输出 top-k 候选,通过笛卡尔积形成多条候选序列,用一个特殊的 attention mask 让每个 token 只能看到自己所在分支的祖先节点,从而一次前向就并行验证多个候选序列,不需要扩大 batch size。
- 构造树之后,我们先用校准集(与测试集同分布,但是不相同的一小批数据),先跑一遍统计每个 head 在不同 rank 位置上的 "准确率"(即排第 k 的候选 token 是真实 token 的概率),然后枚举所有可能的树节点,按 "路径上各 节点 经验准确率的乘积" 排序,取前 64 个节点构造稀疏树。
稀疏树的叶子节点可以在任意深度,高概率路径深、低概率路径早早被截。
# Typical Acceptance —— 典型性接受方案
替代 speculative decoding 中的 rejection sampling。使用原模型的预测概率作为衡量标准,并基于预测分布建立一个阈值来判断接受与否。简单说就是只要候选 token 在原模型分布下 "不算太离谱" 就接受,相比严格的拒绝采样能大幅提升接受率,同时保持生成质量。
** 核心想法:** 只要候选 token 在大模型的预测分布 里 "不算太离谱",就接受它。
"不算太离谱" 用两个阈值刻画。设大模型在该位置的分布为 ,候选 token 是 ,则接受条件是:
其中:
- 是一个绝对阈值(比如 0.09)
- 是分布 的熵
- 是一个相对于分布尖锐度的阈值
直觉:
- 如果大模型很确定(熵小, 大),那阈值就高,要求候选必须是高概率 token。
- 如果大模型很不确定(熵大, 小),那阈值就放低,因为反正怎么选都行。
todo:eagle,n-gram,mtp