注意#
本文档详细介绍了 TensorRT-LLM 的 PyTorch 后端中自回归模型的多头注意力 (MHA)、多查询注意力 (MQA) 和组查询注意力 (GQA) 的实现。 简单回顾一下,多头注意力涉及一系列批量矩阵乘法、softmax 运算和另一个批量矩阵乘法,如 Attention Is All You Need 论文中所述。多查询注意力 (MQA) 和 组查询注意力 (GQA) 是 MHA 的变体,它使用的 KV 头比查询头少。 TensorRT-LLM 在 tensorrt_llm/_torch/attention_backend/
中提供了使用不同后端的几种实现。 以下各节将解释如何使用这些实现,并提供有关实现新后端的简要指南。
注意力后端#
目前有三个可用的注意力后端:vanilla 后端、TRT-LLM 后端和 Flashinfer 后端。您可以使用 PyTorchConfig.attn_backend
指定所需的注意力后端。 例如,要使用 Flashinfer 后端,您可以创建一个 PyTorchConfig
,其中 attn_backend = "flashinfer"
,然后将其传递给 LLM
构造函数,如下所示:LLM(pytorch_backend_config=pytorch_config)
。 这将启用 Flashinfer 后端用于您的模型。
vanilla 后端 VanillaAttention
是一种参考实现,主要用于飞行中批处理和线性 KV 缓存支持。 虽然它是一个有用的基线,但由于其有限的优化,不建议用于生产环境。
相比之下,Flashinfer 后端 FlashInferAttention
经过性能优化,支持飞行中批处理和分页 KV 缓存。 它还包括以下高级功能
FP8 量化:此功能可以量化输入和 KV 缓存为 FP8 格式,从而显著减少内存使用并提高计算吞吐量。
RoPE 融合:通过将旋转位置嵌入 (RoPE) 直接集成到注意力计算中,此功能增强了效率并减少了开销。
TRT-LLM 后端 TrtllmAttention
是默认后端,支持 Flashinfer 后端中提供的所有功能,同时针对增强的性能进行了进一步优化。 它是生产环境的推荐选择。 此外,它还提供以下高级功能
融合 QKV 输入:它可以接受单个 QKV 张量作为输入,与使用单独的 Q、K 和 V 张量相比,这更有效。
FP8 输出:它支持以 FP8 格式输出注意力结果,将量化融合到注意力计算过程中。
实现新的注意力后端#
您可以实现一个新的注意力后端来集成其他注意力库。 注意力后端由一个 AttentionBackend
类和一个 AttentionMetadata
类组成。 PyTorch 中有三个阶段涉及注意力后端
模型构建:在模型的
__init__
期间,调用AttentionBackend.__init__
为每一层创建一个注意力后端。元数据准备:在模型的每个前向步骤之前
如果元数据未初始化,则调用
AttentionMetadata.__init__
来创建注意力元数据。如果使用 CUDA 图,则调用
AttentionMetadata.create_cuda_graph_metadata
将元数据转换为 CUDA 图元数据,它会预分配所有张量,并且可以用于捕获 CUDA 图。 使用 CUDA 图时,请勿在初始预热运行后重新分配存储在AttentionMetadata
中的任何张量。要准备输入和 KV 缓存的参数,请调用
AttentionMetadata.prepare
以从现有元数据和 KV 缓存管理器进行转换。
单步前向:在每个注意力层的前向传递期间,调用
AttentionBackend.forward
来执行注意力操作。AttentionMetadata
将作为前向参数提供。
实现 AttentionMetadata
#
AttentionMetadata
类存储来自批量输入和 KV 缓存的元数据,供注意力后端使用。 它包含以下预定义的字段
字段 |
类型 |
描述 |
---|---|---|
max_num_requests |
int |
单个批次中的最大请求数。 |
num_contexts |
int |
批次中上下文阶段序列的数量。 |
num_generations |
int |
批次中生成阶段序列的数量。 |
max_num_tokens |
int |
单个批次中所有请求的最大 token 数。 |
num_tokens |
int |
批次中的 token 数。 |
num_ctx_tokens |
int |
上下文阶段的序列中的 token 数。 |
kv_cache_manager |
KVCacheManager |
KV 缓存管理器。 |
is_cuda_graph |
bool |
是否启用 CUDA 图。 |
seq_lens |
Tensor |
批次中每个序列的长度。 形状为 (batch_size),位于 CPU 内存中。 |
seq_lens_cuda |
Tensor |
|
context_lens |
Tensor |
批次中每个上下文阶段序列的长度。 形状为 ( |
position_ids |
Optional[Tensor] |
每个序列中每个 token 的位置。 如果位置嵌入在后端外部应用,则可能为 None。 |
request_ids |
List[int] |
批次中每个序列的请求 ID。 |
prompt_lens |
List[int] |
批次中每个序列的 prompt 长度。 |
kv_cache_params |
KVCacheParams |
KV 缓存的参数。 |
在 AttentionMetadata.__init__
中,您可以初始化新的注意力元数据的附加字段。例如,Flashinfer 元数据在此处初始化 decode_wrapper
。在 AttentionMetadata.prepare
期间,运行时将填充所有预定义的字段,您可以根据这些预定义的字段填充您自定义的字段。例如,Flashinfer 元数据通过结合 context_lens
和 num_generations
在此处填充 qo_indptr
。
实现 AttentionBackend
#
AttentionBackend
将注意力操作委派给后端实现。
它的 __init__
接受以下参数
字段 |
类型 |
描述 |
---|---|---|
layer_idx |
int |
模型中注意力层的索引。 |
num_heads |
int |
查询头的数量。 |
head_dim |
int |
每个注意力头的大小 |
num_kv_heads |
Optional[int] |
KV头的数量。如果为 None,则默认为 num_heads。 |
quant_config |
QuantConfig |
可选的量化配置。如果为 None,则不应用量化。 |
pos_embd_params |
PositionalEmbeddingParams |
定义如何应用位置嵌入的可选参数。如果为 None,则位置嵌入应由模型在调用后端之前应用。否则,后端负责应用位置嵌入,并且可以缓存 K 而无需先嵌入它。 |
它的 forward
接受以下参数
字段 |
类型 |
描述 |
---|---|---|
q |
Tensor |
形状为 |
k |
Tensor |
形状为 |
v |
Tensor |
形状为 |
metadata |
AttentionMetadata |
注意力操作的元数据。 |
attention_mask |
AttentionMask |
可选的注意力掩码。如果为 None,则应用因果掩码。 |
例如,Flashinfer 后端调用 append_paged_kv_cache
,然后调用 wrapper 的 run
以在此处执行注意力操作。