注意#

本文档详细介绍了 TensorRT-LLM 的 PyTorch 后端中自回归模型的多头注意力 (MHA)、多查询注意力 (MQA) 和组查询注意力 (GQA) 的实现。 简单回顾一下,多头注意力涉及一系列批量矩阵乘法、softmax 运算和另一个批量矩阵乘法,如 Attention Is All You Need 论文中所述。多查询注意力 (MQA)组查询注意力 (GQA) 是 MHA 的变体,它使用的 KV 头比查询头少。 TensorRT-LLM 在 tensorrt_llm/_torch/attention_backend/ 中提供了使用不同后端的几种实现。 以下各节将解释如何使用这些实现,并提供有关实现新后端的简要指南。

注意力后端#

目前有三个可用的注意力后端:vanilla 后端、TRT-LLM 后端和 Flashinfer 后端。您可以使用 PyTorchConfig.attn_backend 指定所需的注意力后端。 例如,要使用 Flashinfer 后端,您可以创建一个 PyTorchConfig,其中 attn_backend = "flashinfer",然后将其传递给 LLM 构造函数,如下所示:LLM(pytorch_backend_config=pytorch_config)。 这将启用 Flashinfer 后端用于您的模型。

vanilla 后端 VanillaAttention 是一种参考实现,主要用于飞行中批处理和线性 KV 缓存支持。 虽然它是一个有用的基线,但由于其有限的优化,不建议用于生产环境。

相比之下,Flashinfer 后端 FlashInferAttention 经过性能优化,支持飞行中批处理和分页 KV 缓存。 它还包括以下高级功能

  1. FP8 量化:此功能可以量化输入和 KV 缓存为 FP8 格式,从而显著减少内存使用并提高计算吞吐量。

  2. RoPE 融合:通过将旋转位置嵌入 (RoPE) 直接集成到注意力计算中,此功能增强了效率并减少了开销。

TRT-LLM 后端 TrtllmAttention 是默认后端,支持 Flashinfer 后端中提供的所有功能,同时针对增强的性能进行了进一步优化。 它是生产环境的推荐选择。 此外,它还提供以下高级功能

  1. 融合 QKV 输入:它可以接受单个 QKV 张量作为输入,与使用单独的 Q、K 和 V 张量相比,这更有效。

  2. FP8 输出:它支持以 FP8 格式输出注意力结果,将量化融合到注意力计算过程中。

实现新的注意力后端#

您可以实现一个新的注意力后端来集成其他注意力库。 注意力后端由一个 AttentionBackend 类和一个 AttentionMetadata 类组成。 PyTorch 中有三个阶段涉及注意力后端

  1. 模型构建:在模型的 __init__ 期间,调用 AttentionBackend.__init__ 为每一层创建一个注意力后端。

  2. 元数据准备:在模型的每个前向步骤之前

    1. 如果元数据未初始化,则调用 AttentionMetadata.__init__ 来创建注意力元数据。

    2. 如果使用 CUDA 图,则调用 AttentionMetadata.create_cuda_graph_metadata 将元数据转换为 CUDA 图元数据,它会预分配所有张量,并且可以用于捕获 CUDA 图。 使用 CUDA 图时,请勿在初始预热运行后重新分配存储在 AttentionMetadata 中的任何张量。

    3. 要准备输入和 KV 缓存的参数,请调用 AttentionMetadata.prepare 以从现有元数据和 KV 缓存管理器进行转换。

  3. 单步前向:在每个注意力层的前向传递期间,调用 AttentionBackend.forward 来执行注意力操作。 AttentionMetadata 将作为前向参数提供。

实现 AttentionMetadata#

AttentionMetadata 类存储来自批量输入和 KV 缓存的元数据,供注意力后端使用。 它包含以下预定义的字段

字段

类型

描述

max_num_requests

int

单个批次中的最大请求数。

num_contexts

int

批次中上下文阶段序列的数量。

num_generations

int

批次中生成阶段序列的数量。

max_num_tokens

int

单个批次中所有请求的最大 token 数。

num_tokens

int

批次中的 token 数。

num_ctx_tokens

int

上下文阶段的序列中的 token 数。

kv_cache_manager

KVCacheManager

KV 缓存管理器。

is_cuda_graph

bool

是否启用 CUDA 图。

seq_lens

Tensor

批次中每个序列的长度。 形状为 (batch_size),位于 CPU 内存中。

seq_lens_cuda

Tensor

seq_lens 的副本存储在 GPU 上。

context_lens

Tensor

批次中每个上下文阶段序列的长度。 形状为 (num_contexts)。

position_ids

Optional[Tensor]

每个序列中每个 token 的位置。 如果位置嵌入在后端外部应用,则可能为 None。

request_ids

List[int]

批次中每个序列的请求 ID。

prompt_lens

List[int]

批次中每个序列的 prompt 长度。

kv_cache_params

KVCacheParams

KV 缓存的参数。

AttentionMetadata.__init__ 中,您可以初始化新的注意力元数据的附加字段。例如,Flashinfer 元数据在此处初始化 decode_wrapper。在 AttentionMetadata.prepare 期间,运行时将填充所有预定义的字段,您可以根据这些预定义的字段填充您自定义的字段。例如,Flashinfer 元数据通过结合 context_lensnum_generations 在此处填充 qo_indptr

实现 AttentionBackend#

AttentionBackend 将注意力操作委派给后端实现。

它的 __init__ 接受以下参数

字段

类型

描述

layer_idx

int

模型中注意力层的索引。

num_heads

int

查询头的数量。

head_dim

int

每个注意力头的大小 (hidden_size // num_heads)

num_kv_heads

Optional[int]

KV头的数量。如果为 None,则默认为 num_heads。

quant_config

QuantConfig

可选的量化配置。如果为 None,则不应用量化。

pos_embd_params

PositionalEmbeddingParams

定义如何应用位置嵌入的可选参数。如果为 None,则位置嵌入应由模型在调用后端之前应用。否则,后端负责应用位置嵌入,并且可以缓存 K 而无需先嵌入它。

它的 forward 接受以下参数

字段

类型

描述

q

Tensor

形状为 (num_tokens, num_heads * head_dim) 的查询张量。

k

Tensor

形状为 (num_tokens, num_kv_heads * head_dim) 的键张量。

v

Tensor

形状为 (num_tokens, num_kv_heads * head_dim) 的值张量。

metadata

AttentionMetadata

注意力操作的元数据。

attention_mask

AttentionMask

可选的注意力掩码。如果为 None,则应用因果掩码。

例如,Flashinfer 后端调用 append_paged_kv_cache,然后调用 wrapper 的 run 以在此处执行注意力操作。