推测采样#
关于推测采样#
推测采样(也称为推测解码)是一组旨在允许在每次前向传播迭代中生成多个令牌的技术。这可以减少每个令牌的平均延迟,**尤其是在 GPU 因批量较小而未得到充分利用的情况下。**
推测采样涉及使用比重复执行目标大型语言模型 (LLM) 效率显著更高的方法来预测一系列未来令牌,称为草稿令牌。然后,通过在一次前向传播中用目标 LLM 处理这些草稿令牌,对它们进行集体验证。其基本假设有两点
并发处理多个草稿令牌的速度与处理单个令牌一样快
在整个生成过程中,多个草稿令牌将成功通过验证
如果第一个假设成立,推测解码的延迟不会比标准方法差。如果第二个假设成立,输出令牌生成将统计上每次前向传播推进多于一个令牌。两者的结合使得推测解码能够降低延迟。
TensorRT-LLM 支持几种生成草稿令牌的方法,包括
利用一个更小的辅助模型,称为草稿模型方法。有关更多信息,请参阅《通过推测解码加速 Transformer 推理》论文。
实现额外的语言模型头来预测未来位置的令牌
利用提示令牌作为草稿令牌。有关更多信息,请参阅Prompt Lookup Decoding。
利用类似 Jacobi 的解码方法,使用同一模型预测和验证草稿令牌,无需额外微调。请参阅《使用 Lookahead 解码打破 LLM 推理的顺序依赖》。
性能改进#
需要注意的是,推测解码技术的有效性高度依赖于具体的任务。例如,在代码补全场景中预测后续令牌可能比生成文章摘要更简单。
此外,当将 Medusa 与可能不如 TensorRT-LLM 精细调优的标准 PyTorch 模型实现集成时,潜在的时间节省会更加显著。
草稿-目标模型#
草稿-目标模型涉及使用两个独立训练但共享相同词汇表的模型(一个较小的草稿模型和一个较大的目标模型)。例如,GPT 125M / 6.7B 模型可以分别用作草稿模型和目标模型。
草稿模型和目标模型的管理通过两个独立的 Executor
实例来简化。您必须有效地协调草稿模型和目标模型之间的交互。最初,查询草稿模型以生成最多 K
个草稿令牌。然后,这些令牌被转发到目标模型进行验证。验证后,目标模型可能返回最多 K+1
个令牌。随后,现在已更新接受令牌的提示被发送回草稿模型,以启动新草稿令牌的生成。此迭代过程持续进行,直到满足预定义的停止条件。此编排过程的一个示例可在 TensorRT-LLM Triton 后端中找到。
我们目前提供两种运行草稿-目标模型的方式:在 Triton Inference Server 中使用 TensorRT-LLM-BLS,或直接使用 TensorRT-LLM。详细运行步骤可在 examples/draft_target_model/README.md 中找到,代码可在 examples/prompt_lookup/run_dtm_pld.py 中找到。
提示查找解码#
提示查找推测解码在生成后续输出时,直接从输入提示和之前生成的输出中复制作为草稿令牌。它的工作原理类似于草稿-目标模型,但只涉及一个目标 LLM 模型,无需额外微调。提示查找受益于输入提示和输出之间具有较高 n-gram 重叠的场景,例如摘要、文档问答、多轮聊天、代码编辑等。
请参阅 examples/prompt_lookup/README.md 中的文档,代码可在 examples/prompt_lookup/run_dtm_pld.py 中找到。
Medusa#
这种方法利用单个模型同时生成和验证草稿令牌。它通过添加多个额外的语言模型头(称为 Medusa 头)来增强现有模型。这些额外的头经过训练用于预测未来令牌,而基础模型保持不变。具体来说,第一个 Medusa 头负责预测紧随其后的令牌,第二个头预测再下一个令牌,依此类推。使用 K
个 Medusa 头,模型可以预测最多 K
个未来的令牌。在迭代 i
期间由 Medusa 头生成的草稿令牌将在随后的迭代 i+1
中得到验证并可能被接受。
Medusa 策略的真正潜力在于每个头使用不止一个令牌时得以实现,采用 TopK 方法创建多个潜在路径,实质上形成一个树状结构,而不是像草稿模型方法中看到的单一线性路径。为了减少冗余计算,许多通常共享相同前缀的路径被合并到一条路径中。这是通过应用带有表示各种路径的稀疏掩码的注意力机制来实现的。Medusa 树形成的稀疏掩码将在后面详细描述。
通过同时验证多条路径,每次迭代接受多个令牌的可能性增加,但这需要额外的计算开销。
必须认识到,随着潜在路径的数量随 K
呈指数级增长,不必探索或验证所有路径。管理这种复杂性的推荐策略是通过仅关注具有更高概率令牌的路径来修剪树。
您必须在想要探索的树的广度和深度与更大树对特定应用程序整体性能的影响之间取得平衡。
在 TensorRT-LLM 的 Medusa 实现中,树的配置是一个运行时参数。这种灵活性使您可以进行实验并确定适用于您用例的最佳树结构,然后可以在生产环境中使用该结构。
Medusa 树#
考虑以下图表,它说明了基础模型最后一层的隐藏状态如何传递给基础模型的语言模型 (LM) 头以及四个 Medusa 头 (MHs)。
在此示例中
令牌
l0
表示模型生成的实际令牌。所有其他令牌,表示为phk
,是 MHs 的预测,其中h
表示 Medusa 头索引(从 1 开始),k
表示 TopK 选择索引(从 0 开始)。使用了四个 MHs,这意味着模型正在预测四个未来的令牌。
前两个 MHs 使用 Top-2 预测,而后两个使用 Top-1。例如,
p10
和p11
是第一个 Medusa 头 (MH1) 的最高和次高预测。总共探索了四条路径,这比使用完整二叉树(假设所有 MHs 都使用 Top-2 预测)时将探索的 16 条路径要少。
由于其中一些路径可能会被接受,因此有十个潜在的候选,称为
medusa_choices
。每一步可以接受的令牌数量(包括真实令牌)从 1 个(如果所有 Medusa 预测都不正确)到 5 个(如果所有预测都正确)。
在生成阶段,模型接收 10 个令牌的输入,这些输入对应于每个候选路径的最后一个令牌,而不仅仅是一个令牌。
在 TensorRT-LLM 中,您可以选择通过提供所有 Medusa 选择或仅指定唯一路径来定义此类树。
由于每个候选/路径都以真实令牌 (
l0
) 开始,因此无需单独指定它。对于预测的令牌,只需 TopK 索引即可。例如,要指定路径
l0p10p21p30
,可以使用[0,1,0]
。要指定路径l0p11p20
,可以使用[1,0]
。要指定示例中的所有 4 条路径,请使用
medusa_choices=[[0,0,0,0], [0,1,0], [1,0], [1,1]]
。还可以像 Medusa 仓库那样,显式指定所有候选。例如,
medusa_choices=[[0], [0,0], [0,0,0], [0,0,0,0], [0,1], [0,1,0], [1], [1,0], [1,1]]
。注意,在显式指定所有候选时,**我们不包含空[]
候选**,这对应于只接受真实令牌的情况,即所有来自 MHs 的预测都是错误的。因此,只指定了9
个候选。
目前仅在 Python 运行时中支持只指定路径而非所有选择的方式。
在 TensorRT-LLM 中使用 Medusa#
有关使用 Python 运行时构建和执行 Medusa 的指导,请查阅 Medusa README。在使用 C++ API 和 Inflight Fused Batching (IFB) 时,需要在模型配置中显式定义 medusa_choices
。有关详细说明,请参阅 TensorRT-LLM 后端中的模型配置了解更多详情。
限制#
TensorRT-LLM 目前仅支持 Vicuna(经过微调的 LLaMA)的 Medusa。然而,与任何新模型类似,您可以遵循相同的方法来定义您自己的 Medusa 模型并使用 TensorRT-LLM 进行部署。
我们在验证阶段仅匹配令牌,即
medusa_temperature=0
。集束搜索与 Medusa **不**兼容。
ReDrafter#
ReDrafter 方法通过使用同一模型预测和验证令牌来增强单模型 Medusa 方法。然而,与 Medusa 不同的是,它使用循环预测器预测草稿令牌,其中每个草稿令牌都依赖于前一个。此方法还允许使用集束搜索来识别更突出的草稿令牌。有关更多详细信息,请阅读《ReDrafter》论文。
TensorRT-LLM 实现了 ReDrafter 模型,使得 logits 预测、集束搜索和草稿令牌接受都在 TensorRT 引擎内部执行。这与标准模型推理形成对比,后者仅预测 logits 并在引擎外部执行解码。由于引擎预测的是显式的草稿令牌,而不是从 logits 解码的隐式令牌,因此我们将这种推测解码方法归类为 explicit_draft_tokens
。请访问 ReDrafter README,了解有关构建和运行模型的信息。ReDrafter 支持 Inflight Fused Batching 运行时和 Python 静态批量运行时。
EAGLE#
EAGLE 方法通过使用同一模型预测和验证令牌来增强单模型 Medusa 方法。与 ReDrafter 类似,它使用循环预测器预测草稿令牌,其中每个草稿令牌都依赖于前一个。然而,与 ReDrafter 不同的是,它使用单层 Transformer 模型从之前的隐藏状态和解码的令牌中预测草稿令牌。在 EAGLE-1 中,解码树需要在解码过程中已知。在 EAGLE-2 中,通过沿集束搜索最可能的假设来在执行过程中组装此树。
与 ReDrafter 类似,TensorRT-LLM 实现了 EAGLE 模型,使得 logits 预测、草稿令牌接受和草稿令牌生成都在 TensorRT 引擎内部执行。EAGLE-1 和 EAGLE-2 都受支持,而 EAGLE-2 目前处于实验阶段。请访问 EAGLE README,了解有关构建和运行模型的信息。
Lookahead 解码#
Lookahead 解码算法通过同一模型中的两个并行计算分支运行:一个 Lookahead 分支使用固定大小的二维窗口生成 n-gram,一个验证分支验证有希望的 n-gram 候选。此方法无需额外的模型训练或微调,并且可以为任何自回归模型启用。请参阅 Lookahead 解码 README,了解有关构建和运行模型的信息。