Transformer Engine
1.2.0dev-e10997b
选择版本
当前版本
旧版本
首页
入门
安装
准备工作
NGC 容器中的 Transformer Engine
pip - 来自 GitHub
其他准备工作
安装(稳定版本)
安装(开发版本)
安装(从源码)
入门
概览
让我们构建一个 Transformer 层!
了解 Transformer Engine
融合 TE 模块
启用 FP8
Python API 文档
通用 API
格式
DelayedScaling
特定框架 API
PyTorch
Linear
forward
set_tensor_parallel_group
LayerNorm
RMSNorm
LayerNormLinear
forward
set_tensor_parallel_group
LayerNormMLP
forward
set_tensor_parallel_group
DotProductAttention
forward
set_context_parallel_group
MultiheadAttention
forward
set_context_parallel_group
set_tensor_parallel_group
TransformerLayer
forward
set_context_parallel_group
set_tensor_parallel_group
InferenceParams
CudaRNGStatesTracker
add
fork
get_states
reset
set_states
fp8_autocast
fp8_model_init
checkpoint
onnx_export
JAX
MajorShardingType
ShardingType
TransformerLayerType
ShardingResource
fp8_autocast
update_collections
update_fp8_metas
LayerNorm
__call__
DenseGeneral
__call__
LayerNormDenseGeneral
__call__
LayerNormMLP
__call__
RelativePositionBiases
__call__
MultiHeadAttention
__call__
TransformerLayer
__call__
extend_logical_axis_rules
paddle
Linear
forward
LayerNorm
LayerNormLinear
forward
LayerNormMLP
forward
FusedScaleMaskSoftmax
forward
DotProductAttention
forward
MultiHeadAttention
forward
TransformerLayer
forward
fp8_autocast
recompute
示例和教程
将 FP8 与 Transformer Engine 一起使用
FP8 简介
结构
混合精度训练 - 快速入门
使用 FP8 进行混合精度训练
将 FP8 与 Transformer Engine 一起使用
FP8 配方
FP8 自动转换
处理反向传播
精度
性能优化
多 GPU 训练
梯度累积融合
FP8 权重缓存
高级
C/C++ API
activation.h
void nvte_gelu
void nvte_dgelu
void nvte_geglu
void nvte_dgeglu
void nvte_relu
void nvte_drelu
void nvte_swiglu
void nvte_dswiglu
void nvte_reglu
void nvte_dreglu
cast.h
void nvte_fp8_quantize
void nvte_fp8_dequantize
gemm.h
void nvte_cublas_gemm
void nvte_cublas_atomic_gemm
fused_attn.h
enum NVTE_QKV_Layout
enumerator NVTE_SB3HD
enumerator NVTE_SBH3D
enumerator NVTE_SBHD_SB2HD
enumerator NVTE_SBHD_SBH2D
enumerator NVTE_SBHD_SBHD_SBHD
enumerator NVTE_BS3HD
enumerator NVTE_BSH3D
enumerator NVTE_BSHD_BS2HD
enumerator NVTE_BSHD_BSH2D
enumerator NVTE_BSHD_BSHD_BSHD
enumerator NVTE_T3HD
enumerator NVTE_TH3D
enumerator NVTE_THD_T2HD
enumerator NVTE_THD_TH2D
enumerator NVTE_THD_THD_THD
enum NVTE_QKV_Layout_Group
enumerator NVTE_3HD
enumerator NVTE_H3D
enumerator NVTE_HD_2HD
enumerator NVTE_HD_H2D
enumerator NVTE_HD_HD_HD
enum NVTE_QKV_Format
enumerator NVTE_SBHD
enumerator NVTE_BSHD
enumerator NVTE_THD
enum NVTE_Bias_Type
enumerator NVTE_NO_BIAS
enumerator NVTE_PRE_SCALE_BIAS
enumerator NVTE_POST_SCALE_BIAS
enumerator NVTE_ALIBI
enum NVTE_Mask_Type
enumerator NVTE_NO_MASK
enumerator NVTE_PADDING_MASK
enumerator NVTE_CAUSAL_MASK
enumerator NVTE_PADDING_CAUSAL_MASK
enum NVTE_Fused_Attn_Backend
enumerator NVTE_No_Backend
enumerator NVTE_F16_max512_seqlen
enumerator NVTE_F16_arbitrary_seqlen
enumerator NVTE_FP8
NVTE_QKV_Layout_Group nvte_get_qkv_layout_group
NVTE_QKV_Format nvte_get_qkv_format
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend
void nvte_fused_attn_fwd_qkvpacked
void nvte_fused_attn_bwd_qkvpacked
void nvte_fused_attn_fwd_kvpacked
void nvte_fused_attn_bwd_kvpacked
void nvte_fused_attn_fwd
void nvte_fused_attn_bwd
layer_norm.h
void nvte_layernorm_fwd
void nvte_layernorm1p_fwd
void nvte_layernorm_bwd
void nvte_layernorm1p_bwd
rmsnorm.h
void nvte_rmsnorm_fwd
void nvte_rmsnorm_bwd
softmax.h
void nvte_scaled_softmax_forward
void nvte_scaled_softmax_backward
void nvte_scaled_masked_softmax_forward
void nvte_scaled_masked_softmax_backward
void nvte_scaled_upper_triang_masked_softmax_forward
void nvte_scaled_upper_triang_masked_softmax_backward
transformer_engine.h
typedef void *NVTETensor
enum NVTEDType
enumerator kNVTEByte
enumerator kNVTEInt32
enumerator kNVTEInt64
enumerator kNVTEFloat32
enumerator kNVTEFloat16
enumerator kNVTEBFloat16
enumerator kNVTEFloat8E4M3
enumerator kNVTEFloat8E5M2
enumerator kNVTENumTypes
NVTETensor nvte_create_tensor
void nvte_destroy_tensor
NVTEDType nvte_tensor_type
NVTEShape nvte_tensor_shape
void *nvte_tensor_data
float *nvte_tensor_amax
float *nvte_tensor_scale
float *nvte_tensor_scale_inv
void nvte_tensor_pack_create
void nvte_tensor_pack_destroy
struct NVTEShape
const size_t *data
size_t ndim
struct NVTETensorPack
NVTETensor tensors[MAX_SIZE]
size_t size = 0
static const int MAX_SIZE = 10
namespace transformer_engine
enum class DType
struct TensorWrapper
transpose.h
void nvte_cast_transpose
void nvte_transpose
void nvte_cast_transpose_dbias
void nvte_fp8_transpose_dbias
void nvte_cast_transpose_dbias_dgelu
void nvte_multi_cast_transpose
void nvte_dgeglu_cast_transpose
Transformer Engine
»
特定框架 API
查看页面源码
特定框架 API
¶
PyTorch
Linear
forward
set_tensor_parallel_group
LayerNorm
RMSNorm
LayerNormLinear
forward
set_tensor_parallel_group
LayerNormMLP
forward
set_tensor_parallel_group
DotProductAttention
forward
set_context_parallel_group
MultiheadAttention
forward
set_context_parallel_group
set_tensor_parallel_group
TransformerLayer
forward
set_context_parallel_group
set_tensor_parallel_group
InferenceParams
CudaRNGStatesTracker
add
fork
get_states
reset
set_states
fp8_autocast
fp8_model_init
checkpoint
onnx_export
JAX
MajorShardingType
ShardingType
TransformerLayerType
ShardingResource
fp8_autocast
update_collections
update_fp8_metas
LayerNorm
__call__
DenseGeneral
__call__
LayerNormDenseGeneral
__call__
LayerNormMLP
__call__
RelativePositionBiases
__call__
MultiHeadAttention
__call__
TransformerLayer
__call__
extend_logical_axis_rules
paddle
Linear
forward
LayerNorm
LayerNormLinear
forward
LayerNormMLP
forward
FusedScaleMaskSoftmax
forward
DotProductAttention
forward
MultiHeadAttention
forward
TransformerLayer
forward
fp8_autocast
recompute