数值精度#
本文档描述了 TensorRT-LLM 中实现的各种量化方案,并包含了不同模型的支持矩阵。
FP32、FP16 和 BF16#
TensorRT-LLM 中实现的各种模型都使用 32 位 IEEE 浮点数 (FP32)。当检查点可用时,模型还支持 16 位 IEEE 浮点数 (FP16) 和 16 位 Bfloat16 (BF16),如此处所述。
量化与反量化 (Q/DQ)#
给定一个浮点数 x
和一个浮点缩放因子 s
,TensorRT-LLM 实现 INT8 量化如下:
q = int8.satfinite(x * s)
给定一个 INT8 数 q
和一个浮点缩放因子 s
,TensorRT-LLM 实现 INT8 反量化到浮点 (FP) 类型如下:
x = static_cast<FP>(q) * s
给定一个形状为 M x N
(M
行 N
列)的矩阵(二维张量),其中 M
是令牌数,N
是通道数。TensorRT-LLM 提供了以下三种模式来量化和反量化张量的元素:
逐张量:对所有元素使用一个单一的缩放因子,
逐令牌:对每个令牌使用不同的缩放因子。在这种情况下有
M
个缩放因子,逐通道:对每个通道使用不同的缩放因子。在这种情况下有
N
个缩放因子。
请注意,逐令牌和逐通道缩放模式可以一起使用(即它们不是互斥的)。
在伪代码中,这三种不同模式的量化可以实现如下:
# Per-tensor scaling.
for mi in range(M):
for ni in range(N):
q[mi][ni] = int8.satfinite(x[mi][ni] * s)
# Per-token scaling.
for mi in range(M):
for ni in range(N):
q[mi][ni] = int8.satfinite(x[mi][ni] * s[mi])
# Per-channel scaling.
for mi in range(M):
for ni in range(N):
q[mi][ni] = int8.satfinite(x[mi][ni] * s[ni])
INT8 SmoothQuant (W8A8)#
SmoothQuant 技术在https://arxiv.org/abs/2211.10438 中介绍。它是一种使用 INT8 对激活和权重进行推理,同时保持网络(在下游任务上)准确性的方法。
如研究论文中所述,必须对模型权重应用预处理。TensorRT-LLM 包含脚本,用于准备模型以使用 SmoothQuant 方法运行。
如何为 GPT、GPT-J 和 LLaMA 启用 SmoothQuant 的示例可以在该版本的 examples/quantization 文件夹中找到。
仅权重 INT4 和 INT8 (W4A16 和 W8A16)#
仅权重 INT4 和 INT8 技术包括对模型的权重进行量化,并在线性层 (Matmuls) 中即时反量化这些权重。激活使用浮点值 (FP16 或 BF16) 编码。
要使用仅权重 INT4/INT8 方法,用户必须确定用于量化和反量化模型权重的缩放因子。
GPTQ 和 AWQ (W4A16)#
GPTQ 和 AWQ 技术分别在https://arxiv.org/abs/2210.17323 和 https://arxiv.org/abs/2306.00978 中介绍。TensorRT-LLM 支持线性层中的逐组缩放因子和零偏移,以实现 GPTQ 和 AWQ 方法。有关详细信息,请参阅 WeightOnlyGroupwiseQuantMatmulPlugin 插件和相应的 weight_only_groupwise_quant_matmul
Python 函数。
本版本包含将 GPTQ 应用于 GPT-NeoX 和 LLaMA-v2 的示例,以及使用 AWQ 处理 GPT-J 的示例。这些示例是实验性实现,可能会在未来的版本中演变。
FP8 (Hopper)#
此版本的 TensorRT-LLM 包含了适用于 GPT-NeMo、GPT-J 和 LLaMA 的 FP8 实现。这些示例可以在 examples/quantization 中找到。
NVFP4 (Blackwell)#
Llama 和 Mixtral 可以使用 NVFP4 数据类型运行。这些示例可以在 Llama 示例中找到。
支持矩阵#
本版本 TensorRT-LLM 包含以下示例:
模型 |
FP32 |
FP16 |
BF16 |
FP8 |
NVFP4 |
W8A8 SQ |
W8A16 |
W4A16 |
W4A16 AWQ |
W4A16 GPTQ |
---|---|---|---|---|---|---|---|---|---|---|
Baichuan |
是 |
是 |
是 |
是 |
. |
是 |
是 |
是 |
是 |
是 |
BERT |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
BLIP-2 |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
BLOOM |
是 |
是 |
是 |
是 |
. |
是 |
是 |
是 |
. |
. |
ChatGLM |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
ChatGLM-v2 |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
ChatGLM-v3 |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
DBRX |
是 |
是 |
是 |
. |
. |
. |
是 |
是 |
. |
. |
Falcon |
是 |
是 |
是 |
是 |
. |
. |
是 |
是 |
是 |
. |
Flan-T5 |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
Gemma |
是 |
是 |
是 |
是 |
. |
是 |
是 |
是 |
是 |
. |
GPT |
是 |
是 |
是 |
是 |
. |
是 |
是 |
是 |
. |
. |
GPT-J |
是 |
是 |
是 |
是 |
. |
是 |
是 |
是 |
是 |
. |
GPT-NeMo |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
GPT-NeoX |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
是 |
InternLM |
是 |
是 |
是 |
. |
. |
是 |
是 |
是 |
. |
. |
InternLM2 |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
LLaMA |
是 |
是 |
是 |
是 |
是 |
是 |
是 |
是 |
是 |
是 |
LLaMA-v2 |
是 |
是 |
是 |
是 |
是 |
是 |
是 |
是 |
是 |
是 |
Mamba |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
Mistral |
是 |
是 |
是 |
是 |
. |
是 |
是 |
是 |
是 |
. |
Mixtral |
是 |
是 |
是 |
是 |
是 |
. |
是 |
是 |
. |
. |
MPT |
是 |
是 |
是 |
是 |
. |
是 |
是 |
是 |
是 |
. |
OPT |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
Phi |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
Qwen |
是 |
是 |
是 |
. |
. |
是 |
是 |
是 |
是 |
是 |
RecurrentGemma |
是 |
是 |
是 |
是 |
. |
是 |
. |
. |
是 |
. |
Replit Code |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
SantaCoder |
是 |
是 |
是 |
. |
. |
. |
是 |
是 |
. |
. |
Skywork |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
StarCoder1 |
是 |
是 |
是 |
. |
. |
. |
是 |
是 |
. |
. |
StarCoder2 |
是 |
是 |
是 |
是 |
. |
. |
是 |
是 |
. |
. |
T5 |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
Whisper |
是 |
是 |
是 |
. |
. |
. |
是 |
是 |
. |
. |
BLIP2-OPT |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
BLIP2-T5 |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
LLaVA |
是 |
是 |
是 |
是 |
. |
是 |
是 |
是 |
是 |
是 |
VILA |
是 |
是 |
是 |
是 |
. |
是 |
是 |
是 |
是 |
是 |
Nougat |
是 |
是 |
是 |
. |
. |
. |
. |
. |
. |
. |
注意:多模态模型(BLIP2-OPT/BLIP2-T5/LLaVA/VILA/Nougat)的视觉组件默认使用 FP16。语言组件决定了给定多模态模型支持哪些量化方法。
技术细节:QuantMode 标志#
量化方法由 QuantMode
标志控制。不同的字段如下:
INT4_WEIGHTS
,权重被量化到 4 位 (W4A*),INT8_WEIGHTS
,权重被量化到 8 位 (W8A*),ACTIVATIONS
,激活被量化到 8 位 (W*A8),PER_CHANNEL
,缩放因子按通道定义,PER_TOKEN
,缩放因子按令牌定义,PER_GROUP
,缩放因子按组定义。
还有三个附加标志用于控制 TensorRT-LLM:
INT8_KV_CACHE
,K/V 缓存使用 8 位整数存储 K 和 V,FP8_KV_CACHE
,K/V 缓存使用 8 位浮点数存储 K 和 V,FP8_QDQ
,TensorRT-LLM 依赖于 TensorRT 中 Q/DQ 节点的自动融合。