使用 FP8 与 Transformer Engine

H100 GPU 引入了对新数据类型 FP8(8 位浮点)的支持,从而提高了矩阵乘法和卷积的吞吐量。在本例中,我们将介绍 FP8 数据类型,并展示如何将其与 Transformer Engine 一起使用。

FP8 简介

结构

H100 支持的 FP8 数据类型实际上是 2 种不同的数据类型,可用于神经网络训练的不同部分

  • E4M3 - 它由 1 个符号位、4 个指数位和 3 个尾数位组成。它可以存储高达 +/-448 的值和 nan

  • E5M2 - 它由 1 个符号位、5 个指数位和 2 个尾数位组成。它可以存储高达 +/-57344 的值、+/- infnan。动态范围增加的代价是存储值的精度较低。

56588e8d983a42408d5c7f894a6c98a7

图 1:浮点数据类型的结构。所有显示的值(在 FP16、BF16、FP8 E4M3 和 FP8 E5M2 中)都是值 0.3952 的最接近表示。

在训练神经网络期间,可以使用这两种类型。通常,正向激活和权重需要更高的精度,因此 E4M3 数据类型最适合在正向传播期间使用。然而,在反向传播中,流经网络的梯度通常不太容易受到精度损失的影响,但需要更高的动态范围。因此,最好使用 E5M2 数据格式存储它们。H100 TensorCores 提供对这些类型任意组合作为输入的支持,使我们能够使用其首选精度存储每个张量。

混合精度训练 - 快速入门

为了理解如何使用 FP8 训练深度学习模型,首先回顾一下混合精度如何与其他数据类型(尤其是 FP16)一起使用是很有帮助的。

FP16 训练的混合精度配方包含 2 个组成部分:选择应以 FP16 执行的操作和动态损失缩放。

  • 选择以 FP16 精度执行的操作需要分析输出相对于操作输入的数值行为以及预期的性能优势。这使得可以将矩阵乘法、卷积和归一化层等操作标记为安全,同时将 normexp 操作保留为需要高精度的操作。

  • 动态损失缩放能够避免训练期间梯度的溢出和下溢。这些情况可能会发生,因为虽然 FP16 的动态范围足以存储梯度值的分布,但此分布可能以对于 FP16 来说过高或过低的值为中心。缩放损失会将这些分布(通过仅使用 2 的幂而不影响数值)移动到 FP16 中可表示的范围内。

ff1485fcc734453094a1a74f5ebacc41

图 2:缩放损失能够将梯度分布移动到 FP16 数据类型的可表示范围内。

使用 FP8 进行混合精度训练

虽然 FP8 类型提供的动态范围足以存储任何特定的激活或梯度,但它不足以同时存储所有这些。这使得适用于 FP16 的单一损失缩放因子策略对于 FP8 训练来说是不可行的,而是需要为每个 FP8 张量使用不同的缩放因子。

有多种策略可用于选择适用于给定 FP8 张量的缩放因子

  • 即时缩放。此策略根据所生成张量的绝对值的最大值 (amax) 选择缩放因子。在实践中,这是不可行的,因为它需要多次传递数据 - 运算符以更高的精度生成并写出输出,然后找到输出的绝对值的最大值并将其应用于所有值以获得最终的 FP8 输出。这会导致大量开销,严重削弱了使用 FP8 带来的收益。

  • 延迟缩放。此策略根据在先前若干次迭代中看到的绝对值的最大值选择缩放因子。这使得 FP8 计算能够充分发挥性能,但需要将最大值的历史记录存储为 FP8 运算符的附加参数。

a913bd4dd56c4c44b917620d312623dc

图 3:延迟缩放策略。FP8 运算符使用使用在先前若干次迭代中看到的 amax(绝对值的最大值)的历史记录获得的缩放因子,并生成 FP8 输出和当前 amax,该 amax 存储在历史记录中。

正如在图 3 中看到的那样,延迟缩放策略既需要存储 amax 的历史记录,还需要选择一种将该历史记录转换为在下一次迭代中使用的缩放因子的方法。

使用 FP8 与 Transformer Engine

Transformer Engine 库提供了使用延迟缩放策略轻松使用 FP8 数据类型进行训练的工具。

FP8 配方

来自 transformer_engine.common.recipe 模块的 DelayedScaling 配方存储了 FP8 训练所需的所有选项 - 用于计算缩放因子的 amax 历史记录长度、FP8 数据格式等。

[1]:
from transformer_engine.common.recipe import Format, DelayedScaling

fp8_format = Format.HYBRID  # E4M3 during forward pass, E5M2 during backward pass
fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")

此配方随后用于配置 FP8 训练。

FP8 自动类型转换

并非所有操作都可以安全地使用 FP8 执行。Transformer Engine 库提供的所有模块都旨在从 FP8 数据类型中提供最大的性能优势,同时保持准确性。为了启用 FP8 操作,TE 模块需要封装在 fp8_autocast 上下文管理器中。

[2]:
import transformer_engine.pytorch as te
import torch

torch.manual_seed(12345)

my_linear = te.Linear(768, 768, bias=True)

inp = torch.rand((1024, 768)).cuda()

with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
    out_fp8 = my_linear(inp)

fp8_autocast 上下文管理器隐藏了处理 FP8 的复杂性

  • 所有 FP8 安全的操作都会将其输入转换为 FP8

  • Amax 历史记录已更新

  • 新的缩放因子已计算完毕,可以用于下一次迭代

注意

Transformer Engine 的线性层中对 FP8 的支持目前仅限于形状的两个维度都可以被 16 整除的张量。就完整 Transformer 网络的输入而言,这通常需要将序列长度填充为 16 的倍数。

处理反向传播

当模型在 fp8_autocast 区域内运行时,尤其是在多 GPU 训练中,需要进行一些通信以同步缩放因子和 amax 历史记录。为了在不引入太多开销的情况下执行该通信,fp8_autocast 上下文管理器会在执行通信之前聚合张量。

由于这种聚合,反向调用需要在 fp8_autocast 上下文管理器之外发生。它对计算精度没有影响 - 反向传播的精度由正向传播的精度决定。

[3]:
loss_fp8 = out_fp8.mean()

loss_fp8.backward()  # This backward pass uses FP8, since out_fp8 was calculated inside fp8_autocast

out_fp32 = my_linear(inp)
loss_fp32 = out_fp32.mean()
loss_fp32.backward()  # This backward pass does not use FP8, since out_fp32 was calculated outside fp8_autocast

精度

如果我们将 FP32 和 FP8 执行的结果进行比较,我们会发现它们相对接近,但又有所不同

[4]:
out_fp8
[4]:
tensor([[ 0.2276,  0.2627,  0.3001,  ...,  0.0346,  0.2211,  0.1188],
        [-0.0963, -0.3725,  0.1717,  ...,  0.0901,  0.0522, -0.3472],
        [ 0.4526,  0.3482,  0.5976,  ..., -0.0687, -0.0382,  0.1566],
        ...,
        [ 0.1698,  0.6061,  0.0385,  ..., -0.2875, -0.1152, -0.0260],
        [ 0.0679,  0.2946,  0.2751,  ..., -0.2284,  0.0517, -0.1441],
        [ 0.1865,  0.2353,  0.9172,  ...,  0.1085,  0.1135,  0.1438]],
       device='cuda:0', grad_fn=<_LinearBackward>)
[5]:
out_fp32
[5]:
tensor([[ 0.2373,  0.2674,  0.2980,  ...,  0.0233,  0.2498,  0.1131],
        [-0.0767, -0.3778,  0.1862,  ...,  0.0858,  0.0676, -0.3369],
        [ 0.4615,  0.3593,  0.5813,  ..., -0.0779, -0.0349,  0.1422],
        ...,
        [ 0.1914,  0.6038,  0.0382,  ..., -0.2847, -0.0991, -0.0423],
        [ 0.0864,  0.2895,  0.2719,  ..., -0.2388,  0.0772, -0.1541],
        [ 0.2019,  0.2275,  0.9027,  ...,  0.1022,  0.1300,  0.1444]],
       device='cuda:0', grad_fn=<_LinearBackward>)

发生这种情况是因为在 FP8 情况下,输入和权重在计算之前都转换为 FP8。 如果我们使用 FP8 中可表示的输入(使用 quickstart_utils.py 中定义的函数)而不是原始输入,我们可以看到这一点

[6]:
from quickstart_utils import cast_to_representable

inp_representable = cast_to_representable(inp)
my_linear.weight.data = cast_to_representable(my_linear.weight.data)

out_fp32_representable = my_linear(inp_representable)

print(out_fp32_representable)
tensor([[ 0.2276,  0.2629,  0.3000,  ...,  0.0346,  0.2211,  0.1188],
        [-0.0963, -0.3724,  0.1717,  ...,  0.0901,  0.0522, -0.3470],
        [ 0.4526,  0.3479,  0.5976,  ..., -0.0686, -0.0382,  0.1566],
        ...,
        [ 0.1698,  0.6062,  0.0385,  ..., -0.2876, -0.1152, -0.0260],
        [ 0.0679,  0.2947,  0.2750,  ..., -0.2284,  0.0516, -0.1441],
        [ 0.1865,  0.2353,  0.9170,  ...,  0.1085,  0.1135,  0.1438]],
       device='cuda:0', grad_fn=<_LinearBackward>)

这次差异非常小

[7]:
out_fp8 - out_fp32_representable
[7]:
tensor([[ 4.9591e-05, -1.9073e-04,  9.5367e-05,  ..., -3.8147e-06,
          4.1962e-05,  2.2888e-05],
        [ 2.2888e-05, -3.4332e-05,  2.2888e-05,  ...,  2.6703e-05,
          5.3406e-05, -1.4114e-04],
        [-3.8147e-05,  2.6703e-04, -3.8147e-06,  ..., -5.7220e-05,
          4.1962e-05, -1.9073e-05],
        ...,
        [ 1.1444e-05, -7.2479e-05, -3.8147e-06,  ...,  5.3406e-05,
         -1.5259e-05,  2.2888e-05],
        [ 4.9591e-05, -9.5367e-05,  6.8665e-05,  ..., -1.5259e-05,
          7.6294e-05,  4.5776e-05],
        [-1.5259e-05, -7.6294e-06,  1.8692e-04,  ..., -3.0518e-05,
         -4.5776e-05,  7.6294e-06]], device='cuda:0', grad_fn=<SubBackward0>)

来自 FP8 执行的结果差异在训练过程中并不重要,但最好理解它们,例如,在调试模型期间。