通用 API

class transformer_engine.common.recipe.Format

支持的 FP8 格式。

  • E4M3 – 所有 FP8 张量均为 e4m3 格式

  • E5M2 – 所有 FP8 张量均为 e5m2 格式

  • HYBRID – 前向传递中的 FP8 张量为 e4m3 格式,后向传递中的 FP8 张量为 e5m2 格式

class transformer_engine.common.recipe.DelayedScaling(margin=0, interval=1, fp8_format=Format.E4M3, amax_history_len=1024, amax_compute_algo='max', scaling_factor_compute_algo=None, override_linear_precision=(False, False, False))

使用延迟缩放因子策略。使用前一次迭代的缩放因子,每隔 interval 重新计算一次,并记录 amax_history_len 步长的 amax 历史记录。

参数
  • margin (int, default = 0) – 缩放因子计算的裕度。

  • interval (int, default = 1) – 控制重新计算缩放因子的频率。

  • fp8_format ({Format.E4M3, Format.HYBRID}, default = Format.HYBRID) – 控制在前向和后向传递期间使用的 FP8 数据格式。

  • amax_history_len (int, default = 1024) – 用于缩放因子计算的 amax 历史窗口的长度。

  • amax_compute_algo ({'max', 'most_recent', Callable}, default = 'max') –

    用于为缩放因子计算选择 amax 值的算法。 有 2 种预定义的选择:max 选择历史窗口中最大的 amax,而 most_recent 始终选择最近看到的值。 或者,可以传递一个签名函数

    def amax_compute(amax_history: Tensor) -> Tensor
    

    其中 Tensor 是一种框架张量类型。

  • scaling_factor_compute_algo (Callable, default = None) –

    用于根据 amax 的值计算新缩放因子的算法。 它应该是一个签名函数

    def scaling_factor_compute(amax: Tensor,
                               old_scaling_factor: Tensor,
                               fp8_max: Tensor,
                               recipe: DelayedScaling) -> Tensor
    

    其中 Tensor 是一种框架张量类型。

  • override_linear_precision (Tuple(bool, bool, bool), default=(False, False, False)) – 在使用 FP8 时,是否以更高的精度执行 fpropdgradwgrad GEMM(分别为)。

  • reduce_amax (bool, default = True) – 默认情况下,如果初始化了 torch.distributed,则 FP8 张量的 amax 值会在 fp8_group(在 fp8_autocast 调用中指定)上进行缩减。 这可以使给定分布式组中的 amaxes 和缩放因子保持同步。 如果设置为 False,则会跳过此缩减,并且每个 GPU 都会维护本地 amaxes 和缩放因子。 为了确保在这种情况下,跨检查点边界的结果在数值上是相同的,所有等级都必须进行检查点以存储本地张量。

注释

  • 默认情况下(当 scaling_factor_compute_algo 保持为 None 时),缩放因子是使用公式从最终的 amax 值计算的

    FP8_MAX = maximum_representable_value(fp8_format)
    new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin)