Jax¶
- class transformer_engine.jax.MajorShardingType¶
- class transformer_engine.jax.ShardingType¶
- class transformer_engine.jax.flax.TransformerLayerType¶
TransformerLayerType 是一个 Enum 类,用于指定 TransformerLayer 的类型
- 值
ENCODER – TransformerLayer 的编码器类型。
DECODER – TransformerLayer 的解码器类型。
- class transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None)¶
- transformer_engine.jax.fp8_autocast(enabled: bool = False, fp8_recipe: Optional[transformer_engine.common.recipe.DelayedScaling] = None, mesh_resource: Optional[transformer_engine.jax.sharding.MeshResource] = None)¶
用于 FP8 使用的上下文管理器。
mesh_shape = (4, 2) dp_mesh_axis_name = 'data_parallel' tp_mesh_axis_name = 'tensor_parallel' devices = np.asarray(jax.devices()).reshape(*mesh_shape) with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)): mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name) with fp8_autocast(enabled=True, mesh_resource=mesh_resource): rules = extend_logical_axis_rules(tuple()) transformer = TransformerLayer() with partitioning.axis_rules(rules): pjit(transformer.init, ...)(...)
注意
目前,我们在 recipe.DelayedScaling 中仅支持
margin
,fp8_format
,interval
,amax_history_len
和 :attr:amax_compute_algo
(值为 ‘max’ 和 ‘most_recent’)。recipe.DelayedScaling 中的其他参数将触发断言。- 参数
enabled (bool, 默认值 = False) – 是否启用 fp8
fp8_recipe (recipe.DelayedScaling, 默认值 = None) – 用于 FP8 训练的配置。
mesh_resource (MeshResource, 默认值 = None) – 指定用于数据并行和张量并行的网格轴。如果设置为 None,则不使用数据并行或张量并行。
- transformer_engine.jax.update_collections(new: Collection, original: Collection)¶
一个更新 Flax Collection 的辅助函数。
Collection = [dict, flax.core.frozen_dict.FrozenDict]
- 参数
new (Collection) – 包含新数据的 Collection。
original (Collection) – 基础 Collection。
- 返回值
outputs – 更新后的 Collection。
- 返回类型
Collection
- transformer_engine.jax.update_fp8_metas(state: Collection)¶
通过以下公式计算新的 fp8 比例因子及其倒数
sf = (fp8_max / amax) / (2 ^ margin) sf = sf if amax > 0.0, else original_scale updated_scale = sf if isfinite(amax), else original_scale) updated_scale_inv = 1/updated_scale
Collection = [dict, flax.core.frozen_dict.FrozenDict]
- 参数
state (Collection) – 包含 FP8 元数据的 Collection。
- 返回值
outputs – 包含更新后的 FP8 元数据的 Collection。
- 返回类型
Collection
- class transformer_engine.jax.flax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)¶
对输入的小批量数据应用层归一化。此模块支持两种归一化类型:普通层归一化和均方根层归一化。
普通层归一化如论文 Layer Normalization 中所述
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]\(\gamma\) 和 \(\beta\) 是与每个输入样本大小相同的可学习仿射变换参数。
均方根层归一化 (RMSNorm) 如论文 Root Mean Square Layer Normalization 中所述
\[y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma\]\[RMS = \sqrt{\mathrm{E}[x^2]}\]\(\gamma\) 是与每个输入样本大小相同的可学习仿射变换参数。
- 参数
epsilon (float, 默认值 = 1e-6) – 添加到层归一化分母中的一个值,用于数值稳定性。
layernorm_type ({'layernorm', 'rmsnorm'}, 默认值 = 'layernorm') – 指示层归一化的类型。
zero_centered_gamma (bool, 默认值 = False) –
如果设置为 True,LayerNorm 公式变为
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]此参数仅适用于 ‘layernorm’。 scale_init 的默认值也会改变。请参阅 scale_init。
scale_init (Initializer, 默认值 = None) – 用于初始化比例因子 \(\gamma\)。如果提供 None,则 scale_init 根据 zero_centered_gamma 的值设置。如果 zero_centered_gamma 设置为 True,则 scale_init 是 flax.linen.initializers.zeros。否则,scale_init 是 flax.linen.initializers.ones。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。
scale_axes (Tuple[str, ...], 默认值 = ('embed', )) – 用于通过相应的 mesh 对比例因子 \(\gamma\) 进行分片的轴名称。
bias_init (Initializer, 默认值 = flax.linen.initializers.zeros) – 用于初始化偏移因子 \(\beta\),仅在
layernorm_type='layernorm'
时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。bias_axes (Tuple[str, ...], 默认值 = ('embed', )) – 用于通过相应的 mesh 对偏移因子 \(\beta\) 进行分片的轴名称。仅在
layernorm_type='layernorm'
时使用。
- 优化参数
dtype (jax.numpy.dtype, 默认值 = jax.numpy.float32) – 用于分配初始参数的数据类型。
transpose_batch_sequence (bool, 默认值 = False) – 指示输入张量的批大小和序列长度维度轴是否已调换。如果设置为 True,输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。
- __call__(x: jax.numpy.ndarray)¶
对输入
inputs
应用层归一化。- 参数
inputs (jax.numpy.ndarray) – 输入张量。
- 返回值
outputs – 输出张量。
- 返回类型
jax.numpy.ndarray
- class transformer_engine.jax.flax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)¶
对输入数据应用线性变换 \(y = xA^T + b\)
- 参数
features (Union[Iterable[int], int]) – 每个输出样本的隐藏大小。
kernel_init (Initializer, 默认值 =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化权重。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。
kernel_axes (Tuple[str, ...], 默认值 = ()) – 用于通过相应的 mesh 对权重进行分片的轴名称。
use_bias (bool, 默认值 = False) – 指示是否启用偏置。如果设置为 False,该层将不会学习加性偏置。
bias_init (Initializer, 默认值 = flax.linen.initializers.zeros) – 用于初始化偏置,仅在
use_bias=True
时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。bias_axes (Tuple[str, ...], 默认值 = ()) – 用于通过相应的 mesh 对偏置进行分片的轴名称,仅在
use_bias=True
时使用。axis (Union[Iterable[int], int], 默认值 = -1) – 应用变换的整数轴元组。
- 优化参数
dtype (jax.numpy.dtype, 默认值 = jax.numpy.float32) – 用于分配初始参数的数据类型。
transpose_batch_sequence (bool, 默认值 = True) – 指示输入张量的批大小和序列长度维度轴是否已调换。如果设置为 True,输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。
- __call__(inputs: Array)¶
对输入应用线性变换。
- 参数
inputs (jax.numpy.ndarray) – 输入张量。
- 返回值
outputs – 输出张量。
- 返回类型
jax.numpy.ndarray
- class transformer_engine.jax.flax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)¶
对输入数据应用层归一化,然后应用线性变换。
- 参数
features (Union[Iterable[int], int]) – 每个输出样本的隐藏大小。
enable_layernorm (bool, 默认值 = True) – 指示是否在进行线性变换之前启用层归一化。
layernorm_type ({'layernorm', 'rmsnorm'}, 默认值 = 'layernorm') – 指示层归一化的类型。
epsilon (float, 默认值 = 1e-6) – 添加到层归一化分母中的一个值,用于数值稳定性。
zero_centered_gamma (bool, 默认值 = False) –
如果设置为 True,LayerNorm 公式变为
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]此参数仅适用于 ‘layernorm’。 scale_init 的默认值也会改变。请参阅 scale_init
scale_init (Initializer, 默认值 = None) – 用于初始化比例因子 \(\gamma\)。如果提供 None,则 scale_init 根据 zero_centered_gamma 的值设置。如果 zero_centered_gamma 设置为 True,则 scale_init 是 flax.linen.initializers.zeros。否则,scale_init 是 flax.linen.initializers.ones。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。
scale_axes (Tuple[str, ...], 默认值 = ('embed', )) – 用于通过相应的 mesh 对比例因子 \(\gamma\) 进行分片的轴名称,仅在
enable_layernorm=True
时使用。ln_bias_init (Initializer, 默认值 = flax.linen.initializers.zeros) – 用于初始化偏移因子 \(\beta\),仅在
enable_layernorm=True
和layernorm_type='layernorm'
时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。ln_bias_axes (Tuple[str, ...], 默认值 = ('embed', )) – 用于通过相应的 mesh 对偏移因子 \(\beta\) 进行分片的轴名称。仅在
enable_layernorm=True
和layernorm_type='layernorm'
时使用。kernel_init (Initializer, 默认值 =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化权重。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。
kernel_axes (Tuple[str, ...], 默认值 = ()) – 用于通过相应的 mesh 对权重进行分片的轴名称。
use_bias (bool, 默认值 = False) – 指示是否启用偏置。如果设置为 False,该层将不会学习加性偏置。
bias_init (Initializer, 默认值 = flax.linen.initializers.zeros) – 用于初始化偏置,仅在
use_bias=True
时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。bias_axes (Tuple[str, ...], 默认值 = ()) – 用于通过相应的 mesh 对偏置进行分片的轴名称,仅在
use_bias=True
时使用。return_layernorm_output (bool, 默认值 = True) – 指示是否返回层归一化的输出。如果设置为 False,则在输出中的第二个张量返回 None。
axis (Union[Iterable[int], int], 默认值 = -1) – 应用变换的整数轴元组。
- 优化参数
dtype (jax.numpy.dtype, 默认值 = jax.numpy.float32) – 用于分配初始参数的数据类型。
transpose_batch_sequence (bool, 默认值 = True) – 指示输入张量的批大小和序列长度维度轴是否已调换。如果设置为 True,输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。
depth_scaling (float, 默认值 = None) – 用于缩放 DenseGeneral 输出的因子。它应该是一个浮点数或 None。当设置为 None 时,不应用缩放。
- __call__(inputs: Array)¶
对输入应用层归一化,然后应用线性变换。
- 参数
inputs (jax.numpy.ndarray) – 输入张量。
- 返回值
outputs (jax.numpy.ndarray) – 输出张量。
ln_outputs (jax.numpy.ndarray) – 层归一化的输出张量。如果
return_layernorm_output=False
,则此项为 None。
- class transformer_engine.jax.flax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)¶
对输入应用层归一化,然后应用 MLP 模块,该模块由 2 个连续的线性变换组成,中间由给定的激活函数分隔。
- 参数
intermediate_dim (int, 默认值 = 2048) – 输入样本投影到的中间大小。
enable_layernorm (bool, 默认值 = True) – 指示是否在进行线性变换之前启用层归一化。
layernorm_type ({'layernorm', 'rmsnorm'}, 默认值 = 'layernorm') – 指示层归一化的类型。
epsilon (float, 默认值 = 1e-6) – 添加到层归一化分母中的一个值,用于数值稳定性。
zero_centered_gamma (bool, 默认值 = False) –
如果设置为 True,LayerNorm 公式变为
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]此参数仅适用于 ‘layernorm’。 scale_init 的默认值也会改变。请参阅 scale_init。
scale_init (Initializer, 默认值 = None) – 用于初始化比例因子 \(\gamma\)。如果提供 None,则 scale_init 根据 zero_centered_gamma 的值设置。如果 zero_centered_gamma 设置为 True,则 scale_init 是 flax.linen.initializers.zeros。否则,scale_init 是 flax.linen.initializers.ones。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。
scale_axes (Tuple[str, ...], 默认值 = ('embed', )) – 用于通过相应的 mesh 对比例因子 \(\gamma\) 进行分片的轴名称,仅在
enable_layernorm=True
时使用。ln_bias_init (Initializer, 默认值 = flax.linen.initializers.zeros) – 用于初始化偏移因子 \(\beta\),仅在
enable_layernorm=True
和layernorm_type='layernorm'
时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。ln_bias_axes (Tuple[str, ...], 默认值 = ('embed', )) – 用于通过相应的 mesh 对偏移因子 \(\beta\) 进行分片的轴名称。仅在
enable_layernorm=True
和layernorm_type='layernorm'
时使用。kernel_init (Initializer, 默认值 =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化两次线性变换的权重。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。
kernel_axes_1 (Tuple[str, ...], 默认值 = ('embed', 'act', 'mlp')) – 用于通过相应的 mesh 对第一次线性变换的权重进行分片的轴名称。
kernel_axes_2 (Tuple[str, ...], 默认值 = ('mlp', 'embed')) – 用于通过相应的 mesh 对第二次线性变换的权重进行分片的轴名称。
use_bias (bool, 默认值 = False) – 指示是否启用偏置。如果设置为 False,该层将不会学习加性偏置。
bias_init (Initializer, 默认值 = flax.linen.initializers.zeros) – 用于初始化偏置,仅在
use_bias=True
时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。bias_axes_1 (Tuple[str, ...], 默认值 = ('mlp',)) – 用于通过相应的 mesh 对第一次线性变换的偏置进行分片的轴名称。仅在
use_bias=True
时使用。bias_axes_2 (Tuple[str, ...], 默认值 = ('embed',)) – 用于通过相应的 mesh 对第二次线性变换的偏置进行分片的轴名称。仅在
use_bias=True
时使用。return_layernorm_output (bool, 默认值 = True) – 指示是否返回层归一化的输出。如果设置为 False,则在输出中的第二个张量返回 None。
activations (Sequence[Union[str, Callable]], 默认值 = ('relu',)) – 在第一次线性变换后应用的激活函数序列。每个激活函数都有自己的变换层。
intermediate_dropout_rng_name (str, 默认值 = 'dropout') – 通过 flax.linen.Module.apply 给定的 RNG 中用于生成 Dropout mask 的键。
intermediate_dropout_rate (float, 默认值 = 0.1) – 在
activations
之后的 dropout 操作的 dropout 概率。intermediate_hidden_dropout_dims (Sequence[int], 默认值 = ()) – 隐藏维度中将共享相同 dropout mask 的维度。
axis (Union[Iterable[int], int], 默认值 = -1) – 应用变换的整数轴元组。
- 优化参数
dtype (jax.numpy.dtype, 默认值 = jax.numpy.float32) – 用于分配初始参数的数据类型。
transpose_batch_sequence (bool, 默认值 = True) – 指示输入张量的批大小和序列长度维度轴是否已调换。如果设置为 True,输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。
- __call__(inputs: Array, deterministic: bool = False)¶
对输入应用层归一化,然后应用前馈网络 (MLP 块)。
- 参数
inputs (jax.numpy.ndarray) – 输入张量。
deterministic (bool, 默认值 = False) – 如果设置为 True,则禁用 dropout 操作。
- 返回值
outputs (jax.numpy.ndarray) – 输出张量。
ln_outputs (jax.numpy.ndarray) – 层归一化的输出张量。如果
return_layernorm_output=False
,则此项为 None。
- class transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)¶
将 T5 风格的相对位置嵌入应用到注意力对数 (attention logits)。
- 参数
num_buckets (int) – 用于将 key 和 query 位置之间的距离分桶的桶数。
max_distance (int) – 在所有距离都归入最后一个距离桶之前的最大距离。
num_attention_heads (int) – Transformer 层中的注意力头数。
embedding_init (Initializer, default = flax.linen.linear.default_embed_init) – 用于初始化相对位置嵌入表。
embedding_axes (Tuple[str, ...], default = ('heads', 'relpos_buckets')) – 用于与对应网格一起分割嵌入注意力偏置的轴名称。
- 优化参数
dtype (jax.numpy.dtype, 默认值 = jax.numpy.float32) – 用于分配初始参数的数据类型。
- __call__(q_seqlen, k_seqlen, bidirectional=True)¶
生成相对位置嵌入注意力偏置。
- 参数
q_seqlen (int) – 查询序列的长度。
k_seqlen (int) – 键序列的长度。
bidirectional (bool, default = True) – 指示是否允许正向记忆-查询相对位置嵌入。
- 返回值
output – 一个形状为 (1, num_attention_heads, q_seqlen, k_seqlen) 的注意力偏置。
- 返回类型
jax.numpy.ndarray
- class transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs)¶
多头注意力 (MHA),包括查询 (Query)、键 (Key)、值 (Value) 和输出投影。
注意
当
attn_mask_type
设置为 “causal” 时,参数mask
将被忽略。- 参数
head_dim (int) – 每个注意力头的隐藏维度。
num_heads (int) – 注意力头的数量。
dropout_rate (float, default = 0.0) – 多头注意力操作中的 Dropout 概率。
dropout_rng_name (str, default = 'dropout') – 通过 flax.linen.Module.apply 在给定的 RNG 中用于在核心注意力中生成 Dropout 掩码的键。
layernorm_type ({'layernorm', 'rmsnorm'}, 默认值 = 'layernorm') – 指示层归一化的类型。
layernorm_epsilon (float, default = 1e-6) – 添加到层归一化分母中的一个值,用于数值稳定性。
zero_centered_gamma (bool, 默认值 = False) –
如果设置为 True,LayerNorm 公式变为
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]此参数仅适用于 'layernorm'。
kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘normal’) 用于初始化 QKV 和输出投影权重。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。
use_bias (bool, default = False) – 指示是否启用 QKVO 投影的偏置偏移。如果设置为 False,层将不会学习加性偏置。
bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化 QKVO 投影的偏置,仅在
use_bias=True
时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。apply_residual_connection_post_layernorm (bool, default = False) – 指示是否将残差连接应用于层归一化的输出。
output_layernorm (bool, default = False) – 指示是否在 MHA 结束时应用层归一化。
attn_mask_type ({'causal', 'padding'}, default = 'causal') – 传递给 softmax 操作的注意力掩码类型。在 v0.10.0 中引入。
- 优化参数
dtype (jax.numpy.dtype, 默认值 = jax.numpy.float32) – 用于分配初始参数的数据类型。
fuse_qkv (bool, default = True) – 如果设置为 True,此模块为自注意力中的查询-键-值和交叉注意力中的键-值暴露一个单一的融合参数。
transpose_batch_sequence (bool, default = True) – 指示输入张量的批次和序列长度维度轴是否已交换。如果设置为 True,输入张量应为 (seqlen, batch, hidden) 形状,否则为 (batch, seqlen, hidden) 形状。
scale_attn_logits (bool, default = False) – 指示是否缩放注意力 logits。如果设置为 True,则计算 \(\frac{Q}{\sqrt{head_dim}*K}\),否则计算 \(Q*K\)
scaled_query_init (bool, default = True) – 初始化时是否通过 \(\sqrt{head_dim}\) 缩放 WQ。
float32_logits (bool, default = False) – 是否在 float32 中计算注意力 logits。
- __call__(inputs_q: Array, inputs_kv: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, *, decode: bool = False, deterministic: bool = False)¶
MultiHeadAttention 层:[查询、键、值投影] -> 点积注意力 -> 输出投影。
- 参数
inputs_q (jax.numpy.ndarray) – 用于查询投影的输入张量。
inputs_kv (jax.numpy.ndarray) – 用于键/值投影的输入张量。
mask (jax.numpy.ndarray, default = None) – 用于屏蔽自注意力 softmax 输入的布尔张量。
bias (jax.numpy.ndarray, default = None) – 用于偏移自注意力 softmax 输入的张量。
* –
decode (bool,default = False) – 指示是否准备和使用自回归缓存。
deterministic (bool,default = False) – 如果设置为 True,则禁用 dropout 层。
- 返回值
outputs – 输出张量。
- 返回类型
jax.numpy.ndarray
- class transformer_engine.jax.flax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)¶
TransformerLayer 由相对位置嵌入、注意力块和前馈网络 (MLP) 组成。此标准层基于论文“Attention Is All You Need”。
注意
当
self_attn_mask_type
设置为 “causal” 时,参数attention_mask
将被忽略。- 参数
hidden_size (int, default = 512) – 每个输入样本的隐藏大小。
mlp_hidden_size (int, default = 2048) – 输入样本被投影到的中间大小。
num_attention_heads (int, default = 8) – Transformer 层中的注意力头数量。
layernorm_type ({'layernorm', 'rmsnorm'}, 默认值 = 'layernorm') – 指示层归一化的类型。
layernorm_epsilon (float, default = 1e-6) – 添加到层归一化分母中的一个值,用于数值稳定性。
zero_centered_gamma (bool, 默认值 = False) –
如果设置为 True,LayerNorm 公式变为
\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]此参数仅适用于 'layernorm'。
hidden_dropout (float, default = 0.1) – FC2 层后的 dropout 操作的 dropout 概率。
hidden_dropout_dims (Sequence[int], default = ()) – 将共享相同隐藏层 dropout 掩码的维度。
attention_dropout (float, default = 0.1) – 多头注意力操作中的 dropout 概率。
intermediate_dropout (float, default = 0.1) – FC1 层后的 dropout 操作的 dropout 概率。
intermediate_dropout_dims (Sequence[int], default = ()) – FC1 层后将共享相同隐藏层 dropout 掩码的维度。
dropout_rng_name (str, default = 'dropout') – 通过 flax.linen.Module.apply 在给定的 RNG 中用于在多头注意力中生成 Dropout 掩码的键。
mha_kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘normal’) 用于初始化 QKV 和输出投影的权重。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。
mlp_kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化 FC1 和 FC2 层的权重。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。
mlp_activations (Sequence[str], default = ('relu', )) – 在第一个线性变换后应用的激活函数序列。每个激活函数都有自己的变换层。
use_bias (bool, default = False) – 指示是否启用 QKVO 投影、FC1 和 FC2 的偏置偏移。如果设置为 False,层将不会学习加性偏置。
bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化 QKVO 投影、FC1 和 FC2 的偏置。仅在
use_bias=True
时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。apply_residual_connection_post_layernorm (bool, default = False) – 如果设置为 True,则残差连接取自层归一化的输出(默认为取自层归一化的输入)。
output_layernorm (bool, default = False) – 如果设置为 True,则在输出侧、最终 dropout-add 之后应用层归一化。默认行为是在输入侧、QKV 变换之前应用层归一化。
float32_attention_logits (bool, default = False) – 如果设置为 True,则在 jax.numpy.float32 中执行注意力 logits。
layer_type (TransformerLayerType, default = TransformerLayerType.ENCODER) – 如果设置为 TransformerLayerType.DECODER,则在自注意力后添加一个额外的交叉注意力块。这可以与 TransformerLayerType.ENCODER 选项结合使用,用于像 T5 Transformer 这样的结构。
self_attn_mask_type ({'causal', 'padding'}, default = 'causal') – 传递给 softmax 操作的注意力掩码类型。在 v0.10.0 中引入。
enable_relative_embedding (bool, default = True) – 是否启用相对位置嵌入作为注意力 logits 的偏移。
relative_embedding (flax.linen.Module, default = None) – 用于相对位置嵌入执行的模块,仅在
enable_relative_embedding=True
时使用。默认为 None,如果enable_relative_embedding=True
,将创建一个 RelativePositionBiases 实例。默认值:RelativePositionBiases( num_buckets=32, max_distance=128, num_attention_heads=self.num_attention_heads, dtype=self.dtype, embedding_init=flax.linen.initializers.variance_scaling(1.0, ‘fan_avg’, ‘uniform’), name=’relpos_bias’)
- 优化参数
dtype (jax.numpy.dtype, 默认值 = jax.numpy.float32) – 用于分配初始参数的数据类型。
drop_path (float, default = 0.0) – 当 > 0.0 时,在残差块的主路径中对每个样本应用随机深度。
fuse_qkv_params (bool, default = True) – 如果设置为 True,TransformerLayer 模块为自注意力中的查询-键-值和交叉注意力中的键-值暴露一个单一的融合参数。
transpose_batch_sequence (bool, default = False) – 指示输入张量的批次和序列长度维度轴是否已交换。如果设置为 True,输入张量应为 (seqlen, batch, hidden) 形状,否则为 (batch, seqlen, hidden) 形状。
scale_attn_logits (bool, default = False) – 指示是否缩放注意力 logits。如果设置为 True,则计算 \(\frac{Q}{\sqrt{head_dim}*K}\),否则计算 \(Q*K\)
scaled_query_init (bool, default = True) – 初始化时是否通过 \(\sqrt{head_dim}\) 缩放 WQ。
- __call__(inputs: Array, encoded: Array = None, attention_mask: Array = None, encoder_decoder_mask: Array = None, deterministic: bool = False, decode: bool = False, max_decode_length: bool = None)¶
Transformer 层:注意力块和前馈网络 (MLP)
- 参数
inputs (jax.numpy.ndarray) – 输入张量。
encoded (jax.numpy.ndarray, default = None) – 如果使用
layer_type=TransformerLayerType.DECODER
,则将编码器块的输出张量馈送到解码器块中。attention_mask (jax.numpy.ndarray, default = None) – 用于屏蔽自注意力 softmax 输入的布尔张量。
encoder_decoder_mask (jax.numpy.ndarray, default = None) – 当
layer_type=TransformerLayerType.DECODER
时,用于屏蔽交叉注意力 softmax 输入的布尔张量。deterministic (bool, default = False) – 如果设置为 True,则禁用 dropout 层。
decode (bool,default = False) – 指示是否在多头注意力 (MHA) 中准备和使用自回归缓存。
max_decode_length (bool, default = None) – 当
layer_type=TransformerLayerType.DECODER
且enable_relative_embedding=True
时,生成相对位置嵌入偏置的最大长度。
- 返回值
outputs – 输出张量。
- 返回类型
jax.numpy.ndarray
- transformer_engine.jax.flax.extend_logical_axis_rules(rules: LogicalRules)¶
使用预定义的 TransformerLayer 逻辑轴规则扩展给定的 Flax 逻辑轴规则。
注意
我们目前仅支持单 GPU 训练、数据并行训练和 1D 分片张量并行训练的逻辑轴规则。有关 1D 分片张量并行,请参阅Megatron-LM 张量并行论文中的图 3。
警告
请确保在调用此函数之前已通过 fp8_autocast 设置 ShardingResource。
注意
此函数仅在使用 TransformerLayer 时需要。对于其他模块,例如 DenseGeneral,请正确设置核和偏置的轴。
- 参数
rules (Sequence[Tuple[str, Union[str, None]]]) – 要扩展的基础 Flax 逻辑轴规则。
- 返回值
extended_rules – 扩展后的 Flax 逻辑轴规则。
- 返回类型
Sequence[Tuple[str, Union[str, None]]]