通用 API¶
- class transformer_engine.common.recipe.Format¶
支持的 FP8 格式。
- 值
E4M3 – 所有 FP8 张量均为 e4m3 格式
E5M2 – 所有 FP8 张量均为 e5m2 格式
HYBRID – 前向传递中的 FP8 张量为 e4m3 格式,后向传递中的 FP8 张量为 e5m2 格式
- class transformer_engine.common.recipe.DelayedScaling(margin=0, interval=1, fp8_format=Format.E4M3, amax_history_len=1024, amax_compute_algo='max', scaling_factor_compute_algo=None, override_linear_precision=(False, False, False))¶
使用延迟缩放因子策略。使用前一次迭代的缩放因子,每隔 interval 重新计算一次,并记录 amax_history_len 步长的 amax 历史记录。
- 参数
margin (int, default = 0) – 缩放因子计算的裕度。
interval (int, default = 1) – 控制重新计算缩放因子的频率。
fp8_format ({Format.E4M3, Format.HYBRID}, default = Format.HYBRID) – 控制在前向和后向传递期间使用的 FP8 数据格式。
amax_history_len (int, default = 1024) – 用于缩放因子计算的 amax 历史窗口的长度。
amax_compute_algo ({'max', 'most_recent', Callable}, default = 'max') –
用于为缩放因子计算选择 amax 值的算法。 有 2 种预定义的选择:max 选择历史窗口中最大的 amax,而 most_recent 始终选择最近看到的值。 或者,可以传递一个签名函数
def amax_compute(amax_history: Tensor) -> Tensor
其中 Tensor 是一种框架张量类型。
scaling_factor_compute_algo (Callable, default = None) –
用于根据 amax 的值计算新缩放因子的算法。 它应该是一个签名函数
def scaling_factor_compute(amax: Tensor, old_scaling_factor: Tensor, fp8_max: Tensor, recipe: DelayedScaling) -> Tensor
其中 Tensor 是一种框架张量类型。
override_linear_precision (Tuple(bool, bool, bool), default=(False, False, False)) – 在使用 FP8 时,是否以更高的精度执行 fprop、dgrad 和 wgrad GEMM(分别为)。
reduce_amax (bool, default = True) – 默认情况下,如果初始化了 torch.distributed,则 FP8 张量的 amax 值会在 fp8_group(在 fp8_autocast 调用中指定)上进行缩减。 这可以使给定分布式组中的 amaxes 和缩放因子保持同步。 如果设置为 False,则会跳过此缩减,并且每个 GPU 都会维护本地 amaxes 和缩放因子。 为了确保在这种情况下,跨检查点边界的结果在数值上是相同的,所有等级都必须进行检查点以存储本地张量。
注释
默认情况下(当 scaling_factor_compute_algo 保持为 None 时),缩放因子是使用公式从最终的 amax 值计算的
FP8_MAX = maximum_representable_value(fp8_format) new_scaling_factor = (FP8_MAX / amax) / (2 ^ margin)