@@ -38,7 +38,7 @@ extern "C" {
38
38
* the last dimension. Shape: [N].
39
39
* \param[out] workspace Workspace tensor.
40
40
* \param[in] multiprocessorCount Number of SMs in the device.
41
- * \param[in] zero_centered_gamma If zero_centered_gamma is enabled
41
+ * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
42
42
* \param[in] stream CUDA stream used for the operation.
43
43
*/
44
44
void nvte_layernorm_fwd (const NVTETensor x , const NVTETensor gamma , const NVTETensor beta ,
@@ -70,15 +70,12 @@ void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETe
70
70
* \param[out] dbeta Gradient for beta tensor of shape [H].
71
71
* \param[out] workspace Workspace tensor.
72
72
* \param[in] multiprocessorCount Number of SMs in the device.
73
- * \param[in] zero_centered_gamma If zero_centered_gamma is enabled
73
+ * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
74
74
* \param[in] stream CUDA stream used for the operation.
75
75
*/
76
- void nvte_layernorm_bwd (const NVTETensor dz , // BxSxhidden_size
77
- const NVTETensor x , // BxSxhidden_size
78
- const NVTETensor mu , // BxS, FP32!
79
- const NVTETensor rsigma , // BxS, FP32!
80
- const NVTETensor gamma , // hidden_size
81
- NVTETensor dx , NVTETensor dgamma , NVTETensor dbeta , NVTETensor workspace ,
76
+ void nvte_layernorm_bwd (const NVTETensor dz , const NVTETensor x , const NVTETensor mu ,
77
+ const NVTETensor rsigma , const NVTETensor gamma , NVTETensor dx ,
78
+ NVTETensor dgamma , NVTETensor dbeta , NVTETensor workspace ,
82
79
const int multiprocessorCount , const bool zero_centered_gamma ,
83
80
cudaStream_t stream );
84
81
@@ -105,7 +102,7 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
105
102
* calculated over the last dimension. Shape: [N].
106
103
* \param[out] workspace Workspace tensor.
107
104
* \param[in] multiprocessorCount Number of SMs in the device.
108
- * \param[in] zero_centered_gamma If zero_centered_gamma is enabled
105
+ * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
109
106
* \param[in] stream CUDA stream used for the operation.
110
107
*/
111
108
void nvte_rmsnorm_fwd (const NVTETensor x , const NVTETensor gamma , const float epsilon , NVTETensor z ,
@@ -137,7 +134,7 @@ void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float ep
137
134
* \param[out] dgamma Gradient for gamma tensor of shape [H].
138
135
* \param[out] workspace Workspace tensor.
139
136
* \param[in] multiprocessorCount Number of SMs in the device.
140
- * \param[in] zero_centered_gamma If zero_centered_gamma is enabled
137
+ * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
141
138
* \param[in] stream CUDA stream used for the operation.
142
139
*/
143
140
void nvte_rmsnorm_bwd (const NVTETensor dz , const NVTETensor x , const NVTETensor rsigma ,
0 commit comments