Jax

class transformer_engine.jax.MajorShardingType
class transformer_engine.jax.ShardingType
class transformer_engine.jax.flax.TransformerLayerType

TransformerLayerType 是一个 Enum 类,用于指定 TransformerLayer 的类型

  • ENCODER – TransformerLayer 的编码器类型。

  • DECODER – TransformerLayer 的解码器类型。

class transformer_engine.jax.ShardingResource(dp_resource=None, tp_resource=None)
transformer_engine.jax.fp8_autocast(enabled: bool = False, fp8_recipe: Optional[transformer_engine.common.recipe.DelayedScaling] = None, mesh_resource: Optional[transformer_engine.jax.sharding.MeshResource] = None)

用于 FP8 使用的上下文管理器。

mesh_shape = (4, 2)
dp_mesh_axis_name = 'data_parallel'
tp_mesh_axis_name = 'tensor_parallel'
devices = np.asarray(jax.devices()).reshape(*mesh_shape)

with maps.Mesh(devices, (dp_mesh_axis_name, tp_mesh_axis_name)):
    mesh_resource=MeshResource(dp_mesh_axis_name, tp_mesh_axis_name)

    with fp8_autocast(enabled=True, mesh_resource=mesh_resource):
        rules = extend_logical_axis_rules(tuple())
        transformer = TransformerLayer()

        with partitioning.axis_rules(rules):
            pjit(transformer.init, ...)(...)

注意

目前,我们在 recipe.DelayedScaling 中仅支持 margin, fp8_format, interval, amax_history_len 和 :attr:amax_compute_algo (值为 ‘max’ 和 ‘most_recent’)。recipe.DelayedScaling 中的其他参数将触发断言。

参数
  • enabled (bool, 默认值 = False) – 是否启用 fp8

  • fp8_recipe (recipe.DelayedScaling, 默认值 = None) – 用于 FP8 训练的配置。

  • mesh_resource (MeshResource, 默认值 = None) – 指定用于数据并行和张量并行的网格轴。如果设置为 None,则不使用数据并行或张量并行。

transformer_engine.jax.update_collections(new: Collection, original: Collection)

一个更新 Flax Collection 的辅助函数。

Collection = [dict, flax.core.frozen_dict.FrozenDict]

参数
  • new (Collection) – 包含新数据的 Collection。

  • original (Collection) – 基础 Collection。

返回值

outputs – 更新后的 Collection。

返回类型

Collection

transformer_engine.jax.update_fp8_metas(state: Collection)

通过以下公式计算新的 fp8 比例因子及其倒数

sf = (fp8_max / amax) / (2 ^ margin)
sf = sf if amax > 0.0, else original_scale
updated_scale = sf if isfinite(amax), else original_scale)
updated_scale_inv = 1/updated_scale

Collection = [dict, flax.core.frozen_dict.FrozenDict]

参数

state (Collection) – 包含 FP8 元数据的 Collection。

返回值

outputs – 包含更新后的 FP8 元数据的 Collection。

返回类型

Collection

class transformer_engine.jax.flax.LayerNorm(epsilon=1e-6, layernorm_type='layernorm', **kwargs)

对输入的小批量数据应用层归一化。此模块支持两种归一化类型:普通层归一化和均方根层归一化。

普通层归一化如论文 Layer Normalization 中所述

\[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta\]

\(\gamma\)\(\beta\) 是与每个输入样本大小相同的可学习仿射变换参数。

均方根层归一化 (RMSNorm) 如论文 Root Mean Square Layer Normalization 中所述

\[y = \frac{x}{ \mathrm{RMS}[x] + \epsilon} * \gamma\]
\[RMS = \sqrt{\mathrm{E}[x^2]}\]

\(\gamma\) 是与每个输入样本大小相同的可学习仿射变换参数。

参数
  • epsilon (float, 默认值 = 1e-6) – 添加到层归一化分母中的一个值,用于数值稳定性。

  • layernorm_type ({'layernorm', 'rmsnorm'}, 默认值 = 'layernorm') – 指示层归一化的类型。

  • zero_centered_gamma (bool, 默认值 = False) –

    如果设置为 True,LayerNorm 公式变为

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    此参数仅适用于 ‘layernorm’。 scale_init 的默认值也会改变。请参阅 scale_init

  • scale_init (Initializer, 默认值 = None) – 用于初始化比例因子 \(\gamma\)。如果提供 None,则 scale_init 根据 zero_centered_gamma 的值设置。如果 zero_centered_gamma 设置为 True,则 scale_init 是 flax.linen.initializers.zeros。否则,scale_init 是 flax.linen.initializers.ones。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • scale_axes (Tuple[str, ...], 默认值 = ('embed', )) – 用于通过相应的 mesh 对比例因子 \(\gamma\) 进行分片的轴名称。

  • bias_init (Initializer, 默认值 = flax.linen.initializers.zeros) – 用于初始化偏移因子 \(\beta\),仅在 layernorm_type='layernorm' 时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • bias_axes (Tuple[str, ...], 默认值 = ('embed', )) – 用于通过相应的 mesh 对偏移因子 \(\beta\) 进行分片的轴名称。仅在 layernorm_type='layernorm' 时使用。

优化参数
  • dtype (jax.numpy.dtype, 默认值 = jax.numpy.float32) – 用于分配初始参数的数据类型。

  • transpose_batch_sequence (bool, 默认值 = False) – 指示输入张量的批大小和序列长度维度轴是否已调换。如果设置为 True,输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。

__call__(x: jax.numpy.ndarray)

对输入 inputs 应用层归一化。

参数

inputs (jax.numpy.ndarray) – 输入张量。

返回值

outputs – 输出张量。

返回类型

jax.numpy.ndarray

class transformer_engine.jax.flax.DenseGeneral(features, layernorm_type='layernorm', use_bias=False, **kwargs)

对输入数据应用线性变换 \(y = xA^T + b\)

参数
  • features (Union[Iterable[int], int]) – 每个输出样本的隐藏大小。

  • kernel_init (Initializer, 默认值 =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化权重。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • kernel_axes (Tuple[str, ...], 默认值 = ()) – 用于通过相应的 mesh 对权重进行分片的轴名称。

  • use_bias (bool, 默认值 = False) – 指示是否启用偏置。如果设置为 False,该层将不会学习加性偏置。

  • bias_init (Initializer, 默认值 = flax.linen.initializers.zeros) – 用于初始化偏置,仅在 use_bias=True 时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • bias_axes (Tuple[str, ...], 默认值 = ()) – 用于通过相应的 mesh 对偏置进行分片的轴名称,仅在 use_bias=True 时使用。

  • axis (Union[Iterable[int], int], 默认值 = -1) – 应用变换的整数轴元组。

优化参数
  • dtype (jax.numpy.dtype, 默认值 = jax.numpy.float32) – 用于分配初始参数的数据类型。

  • transpose_batch_sequence (bool, 默认值 = True) – 指示输入张量的批大小和序列长度维度轴是否已调换。如果设置为 True,输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。

__call__(inputs: Array)

对输入应用线性变换。

参数

inputs (jax.numpy.ndarray) – 输入张量。

返回值

outputs – 输出张量。

返回类型

jax.numpy.ndarray

class transformer_engine.jax.flax.LayerNormDenseGeneral(features, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)

对输入数据应用层归一化,然后应用线性变换。

参数
  • features (Union[Iterable[int], int]) – 每个输出样本的隐藏大小。

  • enable_layernorm (bool, 默认值 = True) – 指示是否在进行线性变换之前启用层归一化。

  • layernorm_type ({'layernorm', 'rmsnorm'}, 默认值 = 'layernorm') – 指示层归一化的类型。

  • epsilon (float, 默认值 = 1e-6) – 添加到层归一化分母中的一个值,用于数值稳定性。

  • zero_centered_gamma (bool, 默认值 = False) –

    如果设置为 True,LayerNorm 公式变为

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    此参数仅适用于 ‘layernorm’。 scale_init 的默认值也会改变。请参阅 scale_init

  • scale_init (Initializer, 默认值 = None) – 用于初始化比例因子 \(\gamma\)。如果提供 None,则 scale_init 根据 zero_centered_gamma 的值设置。如果 zero_centered_gamma 设置为 True,则 scale_init 是 flax.linen.initializers.zeros。否则,scale_init 是 flax.linen.initializers.ones。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • scale_axes (Tuple[str, ...], 默认值 = ('embed', )) – 用于通过相应的 mesh 对比例因子 \(\gamma\) 进行分片的轴名称,仅在 enable_layernorm=True 时使用。

  • ln_bias_init (Initializer, 默认值 = flax.linen.initializers.zeros) – 用于初始化偏移因子 \(\beta\),仅在 enable_layernorm=Truelayernorm_type='layernorm' 时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • ln_bias_axes (Tuple[str, ...], 默认值 = ('embed', )) – 用于通过相应的 mesh 对偏移因子 \(\beta\) 进行分片的轴名称。仅在 enable_layernorm=Truelayernorm_type='layernorm' 时使用。

  • kernel_init (Initializer, 默认值 =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化权重。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • kernel_axes (Tuple[str, ...], 默认值 = ()) – 用于通过相应的 mesh 对权重进行分片的轴名称。

  • use_bias (bool, 默认值 = False) – 指示是否启用偏置。如果设置为 False,该层将不会学习加性偏置。

  • bias_init (Initializer, 默认值 = flax.linen.initializers.zeros) – 用于初始化偏置,仅在 use_bias=True 时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • bias_axes (Tuple[str, ...], 默认值 = ()) – 用于通过相应的 mesh 对偏置进行分片的轴名称,仅在 use_bias=True 时使用。

  • return_layernorm_output (bool, 默认值 = True) – 指示是否返回层归一化的输出。如果设置为 False,则在输出中的第二个张量返回 None。

  • axis (Union[Iterable[int], int], 默认值 = -1) – 应用变换的整数轴元组。

优化参数
  • dtype (jax.numpy.dtype, 默认值 = jax.numpy.float32) – 用于分配初始参数的数据类型。

  • transpose_batch_sequence (bool, 默认值 = True) – 指示输入张量的批大小和序列长度维度轴是否已调换。如果设置为 True,输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。

  • depth_scaling (float, 默认值 = None) – 用于缩放 DenseGeneral 输出的因子。它应该是一个浮点数或 None。当设置为 None 时,不应用缩放。

__call__(inputs: Array)

对输入应用层归一化,然后应用线性变换。

参数

inputs (jax.numpy.ndarray) – 输入张量。

返回值

  • outputs (jax.numpy.ndarray) – 输出张量。

  • ln_outputs (jax.numpy.ndarray) – 层归一化的输出张量。如果 return_layernorm_output=False,则此项为 None。

class transformer_engine.jax.flax.LayerNormMLP(intermediate_dim=2048, layernorm_type='layernorm', epsilon=1e-6, use_bias=False, **kwargs)

对输入应用层归一化,然后应用 MLP 模块,该模块由 2 个连续的线性变换组成,中间由给定的激活函数分隔。

参数
  • intermediate_dim (int, 默认值 = 2048) – 输入样本投影到的中间大小。

  • enable_layernorm (bool, 默认值 = True) – 指示是否在进行线性变换之前启用层归一化。

  • layernorm_type ({'layernorm', 'rmsnorm'}, 默认值 = 'layernorm') – 指示层归一化的类型。

  • epsilon (float, 默认值 = 1e-6) – 添加到层归一化分母中的一个值,用于数值稳定性。

  • zero_centered_gamma (bool, 默认值 = False) –

    如果设置为 True,LayerNorm 公式变为

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    此参数仅适用于 ‘layernorm’。 scale_init 的默认值也会改变。请参阅 scale_init

  • scale_init (Initializer, 默认值 = None) – 用于初始化比例因子 \(\gamma\)。如果提供 None,则 scale_init 根据 zero_centered_gamma 的值设置。如果 zero_centered_gamma 设置为 True,则 scale_init 是 flax.linen.initializers.zeros。否则,scale_init 是 flax.linen.initializers.ones。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • scale_axes (Tuple[str, ...], 默认值 = ('embed', )) – 用于通过相应的 mesh 对比例因子 \(\gamma\) 进行分片的轴名称,仅在 enable_layernorm=True 时使用。

  • ln_bias_init (Initializer, 默认值 = flax.linen.initializers.zeros) – 用于初始化偏移因子 \(\beta\),仅在 enable_layernorm=Truelayernorm_type='layernorm' 时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • ln_bias_axes (Tuple[str, ...], 默认值 = ('embed', )) – 用于通过相应的 mesh 对偏移因子 \(\beta\) 进行分片的轴名称。仅在 enable_layernorm=Truelayernorm_type='layernorm' 时使用。

  • kernel_init (Initializer, 默认值 =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化两次线性变换的权重。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • kernel_axes_1 (Tuple[str, ...], 默认值 = ('embed', 'act', 'mlp')) – 用于通过相应的 mesh 对第一次线性变换的权重进行分片的轴名称。

  • kernel_axes_2 (Tuple[str, ...], 默认值 = ('mlp', 'embed')) – 用于通过相应的 mesh 对第二次线性变换的权重进行分片的轴名称。

  • use_bias (bool, 默认值 = False) – 指示是否启用偏置。如果设置为 False,该层将不会学习加性偏置。

  • bias_init (Initializer, 默认值 = flax.linen.initializers.zeros) – 用于初始化偏置,仅在 use_bias=True 时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • bias_axes_1 (Tuple[str, ...], 默认值 = ('mlp',)) – 用于通过相应的 mesh 对第一次线性变换的偏置进行分片的轴名称。仅在 use_bias=True 时使用。

  • bias_axes_2 (Tuple[str, ...], 默认值 = ('embed',)) – 用于通过相应的 mesh 对第二次线性变换的偏置进行分片的轴名称。仅在 use_bias=True 时使用。

  • return_layernorm_output (bool, 默认值 = True) – 指示是否返回层归一化的输出。如果设置为 False,则在输出中的第二个张量返回 None。

  • activations (Sequence[Union[str, Callable]], 默认值 = ('relu',)) – 在第一次线性变换后应用的激活函数序列。每个激活函数都有自己的变换层。

  • intermediate_dropout_rng_name (str, 默认值 = 'dropout') – 通过 flax.linen.Module.apply 给定的 RNG 中用于生成 Dropout mask 的键。

  • intermediate_dropout_rate (float, 默认值 = 0.1) – 在 activations 之后的 dropout 操作的 dropout 概率。

  • intermediate_hidden_dropout_dims (Sequence[int], 默认值 = ()) – 隐藏维度中将共享相同 dropout mask 的维度。

  • axis (Union[Iterable[int], int], 默认值 = -1) – 应用变换的整数轴元组。

优化参数
  • dtype (jax.numpy.dtype, 默认值 = jax.numpy.float32) – 用于分配初始参数的数据类型。

  • transpose_batch_sequence (bool, 默认值 = True) – 指示输入张量的批大小和序列长度维度轴是否已调换。如果设置为 True,输入张量应为 (seqlen, batch, hidden),否则为 (batch, seqlen, hidden)。

__call__(inputs: Array, deterministic: bool = False)

对输入应用层归一化,然后应用前馈网络 (MLP 块)。

参数
  • inputs (jax.numpy.ndarray) – 输入张量。

  • deterministic (bool, 默认值 = False) – 如果设置为 True,则禁用 dropout 操作。

返回值

  • outputs (jax.numpy.ndarray) – 输出张量。

  • ln_outputs (jax.numpy.ndarray) – 层归一化的输出张量。如果 return_layernorm_output=False,则此项为 None。

class transformer_engine.jax.flax.RelativePositionBiases(num_buckets, max_distance, num_heads, **kwargs)

将 T5 风格的相对位置嵌入应用到注意力对数 (attention logits)。

参数
  • num_buckets (int) – 用于将 key 和 query 位置之间的距离分桶的桶数。

  • max_distance (int) – 在所有距离都归入最后一个距离桶之前的最大距离。

  • num_attention_heads (int) – Transformer 层中的注意力头数。

  • embedding_init (Initializer, default = flax.linen.linear.default_embed_init) – 用于初始化相对位置嵌入表。

  • embedding_axes (Tuple[str, ...], default = ('heads', 'relpos_buckets')) – 用于与对应网格一起分割嵌入注意力偏置的轴名称。

优化参数

dtype (jax.numpy.dtype, 默认值 = jax.numpy.float32) – 用于分配初始参数的数据类型。

__call__(q_seqlen, k_seqlen, bidirectional=True)

生成相对位置嵌入注意力偏置。

参数
  • q_seqlen (int) – 查询序列的长度。

  • k_seqlen (int) – 键序列的长度。

  • bidirectional (bool, default = True) – 指示是否允许正向记忆-查询相对位置嵌入。

返回值

output – 一个形状为 (1, num_attention_heads, q_seqlen, k_seqlen) 的注意力偏置。

返回类型

jax.numpy.ndarray

class transformer_engine.jax.flax.MultiHeadAttention(head_dim, num_heads, **kwargs)

多头注意力 (MHA),包括查询 (Query)、键 (Key)、值 (Value) 和输出投影。

注意

attn_mask_type 设置为 “causal” 时,参数 mask 将被忽略。

参数
  • head_dim (int) – 每个注意力头的隐藏维度。

  • num_heads (int) – 注意力头的数量。

  • dropout_rate (float, default = 0.0) – 多头注意力操作中的 Dropout 概率。

  • dropout_rng_name (str, default = 'dropout') – 通过 flax.linen.Module.apply 在给定的 RNG 中用于在核心注意力中生成 Dropout 掩码的键。

  • layernorm_type ({'layernorm', 'rmsnorm'}, 默认值 = 'layernorm') – 指示层归一化的类型。

  • layernorm_epsilon (float, default = 1e-6) – 添加到层归一化分母中的一个值,用于数值稳定性。

  • zero_centered_gamma (bool, 默认值 = False) –

    如果设置为 True,LayerNorm 公式变为

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    此参数仅适用于 'layernorm'。

  • kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘normal’) 用于初始化 QKV 和输出投影权重。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • use_bias (bool, default = False) – 指示是否启用 QKVO 投影的偏置偏移。如果设置为 False,层将不会学习加性偏置。

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化 QKVO 投影的偏置,仅在 use_bias=True 时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • apply_residual_connection_post_layernorm (bool, default = False) – 指示是否将残差连接应用于层归一化的输出。

  • output_layernorm (bool, default = False) – 指示是否在 MHA 结束时应用层归一化。

  • attn_mask_type ({'causal', 'padding'}, default = 'causal') – 传递给 softmax 操作的注意力掩码类型。在 v0.10.0 中引入。

优化参数
  • dtype (jax.numpy.dtype, 默认值 = jax.numpy.float32) – 用于分配初始参数的数据类型。

  • fuse_qkv (bool, default = True) – 如果设置为 True,此模块为自注意力中的查询-键-值和交叉注意力中的键-值暴露一个单一的融合参数。

  • transpose_batch_sequence (bool, default = True) – 指示输入张量的批次和序列长度维度轴是否已交换。如果设置为 True,输入张量应为 (seqlen, batch, hidden) 形状,否则为 (batch, seqlen, hidden) 形状。

  • scale_attn_logits (bool, default = False) – 指示是否缩放注意力 logits。如果设置为 True,则计算 \(\frac{Q}{\sqrt{head_dim}*K}\),否则计算 \(Q*K\)

  • scaled_query_init (bool, default = True) – 初始化时是否通过 \(\sqrt{head_dim}\) 缩放 WQ。

  • float32_logits (bool, default = False) – 是否在 float32 中计算注意力 logits。

__call__(inputs_q: Array, inputs_kv: Array, mask: Optional[Array] = None, bias: Optional[Array] = None, *, decode: bool = False, deterministic: bool = False)

MultiHeadAttention 层:[查询、键、值投影] -> 点积注意力 -> 输出投影。

参数
  • inputs_q (jax.numpy.ndarray) – 用于查询投影的输入张量。

  • inputs_kv (jax.numpy.ndarray) – 用于键/值投影的输入张量。

  • mask (jax.numpy.ndarray, default = None) – 用于屏蔽自注意力 softmax 输入的布尔张量。

  • bias (jax.numpy.ndarray, default = None) – 用于偏移自注意力 softmax 输入的张量。

  • *

  • decode (bool,default = False) – 指示是否准备和使用自回归缓存。

  • deterministic (bool,default = False) – 如果设置为 True,则禁用 dropout 层。

返回值

outputs – 输出张量。

返回类型

jax.numpy.ndarray

class transformer_engine.jax.flax.TransformerLayer(hidden_size=512, mlp_hidden_size=2048, num_attention_heads=8, **kwargs)

TransformerLayer 由相对位置嵌入、注意力块和前馈网络 (MLP) 组成。此标准层基于论文“Attention Is All You Need”。

注意

self_attn_mask_type 设置为 “causal” 时,参数 attention_mask 将被忽略。

参数
  • hidden_size (int, default = 512) – 每个输入样本的隐藏大小。

  • mlp_hidden_size (int, default = 2048) – 输入样本被投影到的中间大小。

  • num_attention_heads (int, default = 8) – Transformer 层中的注意力头数量。

  • layernorm_type ({'layernorm', 'rmsnorm'}, 默认值 = 'layernorm') – 指示层归一化的类型。

  • layernorm_epsilon (float, default = 1e-6) – 添加到层归一化分母中的一个值,用于数值稳定性。

  • zero_centered_gamma (bool, 默认值 = False) –

    如果设置为 True,LayerNorm 公式变为

    \[y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * (1 + \gamma) + \beta\]

    此参数仅适用于 'layernorm'。

  • hidden_dropout (float, default = 0.1) – FC2 层后的 dropout 操作的 dropout 概率。

  • hidden_dropout_dims (Sequence[int], default = ()) – 将共享相同隐藏层 dropout 掩码的维度。

  • attention_dropout (float, default = 0.1) – 多头注意力操作中的 dropout 概率。

  • intermediate_dropout (float, default = 0.1) – FC1 层后的 dropout 操作的 dropout 概率。

  • intermediate_dropout_dims (Sequence[int], default = ()) – FC1 层后将共享相同隐藏层 dropout 掩码的维度。

  • dropout_rng_name (str, default = 'dropout') – 通过 flax.linen.Module.apply 在给定的 RNG 中用于在多头注意力中生成 Dropout 掩码的键。

  • mha_kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘normal’) 用于初始化 QKV 和输出投影的权重。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • mlp_kernel_init (Initializer, default =) – flax.linen.initializers.variance_scaling(1.0, ‘fan_in’, ‘truncated_normal’) 用于初始化 FC1 和 FC2 层的权重。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • mlp_activations (Sequence[str], default = ('relu', )) – 在第一个线性变换后应用的激活函数序列。每个激活函数都有自己的变换层。

  • use_bias (bool, default = False) – 指示是否启用 QKVO 投影、FC1 和 FC2 的偏置偏移。如果设置为 False,层将不会学习加性偏置。

  • bias_init (Initializer, default = flax.linen.initializers.zeros) – 用于初始化 QKVO 投影、FC1 和 FC2 的偏置。仅在 use_bias=True 时使用。它应该是一个可调用对象,接受三个参数 (jax.random.PRNGKey, shape, dtype)。

  • apply_residual_connection_post_layernorm (bool, default = False) – 如果设置为 True,则残差连接取自层归一化的输出(默认为取自层归一化的输入)。

  • output_layernorm (bool, default = False) – 如果设置为 True,则在输出侧、最终 dropout-add 之后应用层归一化。默认行为是在输入侧、QKV 变换之前应用层归一化。

  • float32_attention_logits (bool, default = False) – 如果设置为 True,则在 jax.numpy.float32 中执行注意力 logits。

  • layer_type (TransformerLayerType, default = TransformerLayerType.ENCODER) – 如果设置为 TransformerLayerType.DECODER,则在自注意力后添加一个额外的交叉注意力块。这可以与 TransformerLayerType.ENCODER 选项结合使用,用于像 T5 Transformer 这样的结构。

  • self_attn_mask_type ({'causal', 'padding'}, default = 'causal') – 传递给 softmax 操作的注意力掩码类型。在 v0.10.0 中引入。

  • enable_relative_embedding (bool, default = True) – 是否启用相对位置嵌入作为注意力 logits 的偏移。

  • relative_embedding (flax.linen.Module, default = None) – 用于相对位置嵌入执行的模块,仅在 enable_relative_embedding=True 时使用。默认为 None,如果 enable_relative_embedding=True,将创建一个 RelativePositionBiases 实例。默认值:RelativePositionBiases( num_buckets=32, max_distance=128, num_attention_heads=self.num_attention_heads, dtype=self.dtype, embedding_init=flax.linen.initializers.variance_scaling(1.0, ‘fan_avg’, ‘uniform’), name=’relpos_bias’)

优化参数
  • dtype (jax.numpy.dtype, 默认值 = jax.numpy.float32) – 用于分配初始参数的数据类型。

  • drop_path (float, default = 0.0) – 当 > 0.0 时,在残差块的主路径中对每个样本应用随机深度。

  • fuse_qkv_params (bool, default = True) – 如果设置为 True,TransformerLayer 模块为自注意力中的查询-键-值和交叉注意力中的键-值暴露一个单一的融合参数。

  • transpose_batch_sequence (bool, default = False) – 指示输入张量的批次和序列长度维度轴是否已交换。如果设置为 True,输入张量应为 (seqlen, batch, hidden) 形状,否则为 (batch, seqlen, hidden) 形状。

  • scale_attn_logits (bool, default = False) – 指示是否缩放注意力 logits。如果设置为 True,则计算 \(\frac{Q}{\sqrt{head_dim}*K}\),否则计算 \(Q*K\)

  • scaled_query_init (bool, default = True) – 初始化时是否通过 \(\sqrt{head_dim}\) 缩放 WQ。

__call__(inputs: Array, encoded: Array = None, attention_mask: Array = None, encoder_decoder_mask: Array = None, deterministic: bool = False, decode: bool = False, max_decode_length: bool = None)

Transformer 层:注意力块和前馈网络 (MLP)

参数
  • inputs (jax.numpy.ndarray) – 输入张量。

  • encoded (jax.numpy.ndarray, default = None) – 如果使用 layer_type=TransformerLayerType.DECODER,则将编码器块的输出张量馈送到解码器块中。

  • attention_mask (jax.numpy.ndarray, default = None) – 用于屏蔽自注意力 softmax 输入的布尔张量。

  • encoder_decoder_mask (jax.numpy.ndarray, default = None) – 当 layer_type=TransformerLayerType.DECODER 时,用于屏蔽交叉注意力 softmax 输入的布尔张量。

  • deterministic (bool, default = False) – 如果设置为 True,则禁用 dropout 层。

  • decode (bool,default = False) – 指示是否在多头注意力 (MHA) 中准备和使用自回归缓存。

  • max_decode_length (bool, default = None) – 当 layer_type=TransformerLayerType.DECODERenable_relative_embedding=True 时,生成相对位置嵌入偏置的最大长度。

返回值

outputs – 输出张量。

返回类型

jax.numpy.ndarray

transformer_engine.jax.flax.extend_logical_axis_rules(rules: LogicalRules)

使用预定义的 TransformerLayer 逻辑轴规则扩展给定的 Flax 逻辑轴规则。

注意

我们目前仅支持单 GPU 训练、数据并行训练和 1D 分片张量并行训练的逻辑轴规则。有关 1D 分片张量并行,请参阅Megatron-LM 张量并行论文中的图 3。

警告

请确保在调用此函数之前已通过 fp8_autocast 设置 ShardingResource。

注意

此函数仅在使用 TransformerLayer 时需要。对于其他模块,例如 DenseGeneral,请正确设置核和偏置的轴。

参数

rules (Sequence[Tuple[str, Union[str, None]]]) – 要扩展的基础 Flax 逻辑轴规则。

返回值

extended_rules – 扩展后的 Flax 逻辑轴规则。

返回类型

Sequence[Tuple[str, Union[str, None]]]