rmsnorm.h

RMSNorm 函数。

函数

void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z, NVTETensor rsigma, cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier)

计算输入的 RMSNorm。

如果将 workspace 和 barrier 设置为空张量来调用此函数,它将不会执行操作,而是设置 workspace 和 barrier 张量所需的形状和类型。

参数
  • x[in] 形状为 [N, H] 的输入张量。

  • gamma[in] 形状为 [H] 的 Gamma 张量。

  • epsilon[in] 添加到分母以提高数值稳定性的值。

  • z[inout] 形状为 [N, H] 的输出张量。

  • rsigma[out] 输入在最后一个维度上计算的均方根的倒数。形状:[N]。

  • stream[in] 用于操作的 CUDA 流。

  • multiprocessorCount[in] 设备中的 SM 数量。

  • workspace[out] 工作空间张量。

  • barrier[out] 屏障张量。

void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma, const gamma, NVTETensor dx, NVTETensor dgamma, NVTETensor dgamma_part, cudaStream_t stream, const int multiprocessorCount, NVTETensor workspace, NVTETensor barrier)

计算 RMSNorm 的反向传播。

如果将 workspace, barrier, dgamma_part 设置为空张量来调用此函数,它将不会执行操作,而是设置这些张量所需的形状和类型。

参数
  • dz[in] 形状为 [N, H] 的输入梯度张量。

  • x[in] 形状为 [N, H] 的前向输入张量。

  • rsigma[in] 输入在最后一个维度上计算的均方根的倒数。形状:[N]。

  • gamma[in] 形状为 [H] 的 Gamma 张量。

  • dx[out] 形状为 [N, H] 的输出梯度。

  • dgamma[out] 形状为 [H] 的 gamma 张量的梯度。

  • dgamma_part[out] 部分 gamma 梯度的存储。

  • stream[in] 用于操作的 CUDA 流。

  • multiprocessorCount[in] 设备中的 SM 数量。

  • workspace[out] 工作空间张量。

  • barrier[out] 屏障张量。