Skip to content

Commit 2d5d355

Browse files
committed
reword latex
Signed-off-by: Phuong Nguyen <[email protected]>
1 parent 9fcc51a commit 2d5d355

File tree

1 file changed

+7
-10
lines changed

1 file changed

+7
-10
lines changed

transformer_engine/common/include/transformer_engine/normalization.h

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ extern "C" {
3838
* the last dimension. Shape: [N].
3939
* \param[out] workspace Workspace tensor.
4040
* \param[in] multiprocessorCount Number of SMs in the device.
41-
* \param[in] zero_centered_gamma If zero_centered_gamma is enabled
41+
* \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
4242
* \param[in] stream CUDA stream used for the operation.
4343
*/
4444
void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETensor beta,
@@ -70,15 +70,12 @@ void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETe
7070
* \param[out] dbeta Gradient for beta tensor of shape [H].
7171
* \param[out] workspace Workspace tensor.
7272
* \param[in] multiprocessorCount Number of SMs in the device.
73-
* \param[in] zero_centered_gamma If zero_centered_gamma is enabled
73+
* \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
7474
* \param[in] stream CUDA stream used for the operation.
7575
*/
76-
void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
77-
const NVTETensor x, // BxSxhidden_size
78-
const NVTETensor mu, // BxS, FP32!
79-
const NVTETensor rsigma, // BxS, FP32!
80-
const NVTETensor gamma, // hidden_size
81-
NVTETensor dx, NVTETensor dgamma, NVTETensor dbeta, NVTETensor workspace,
76+
void nvte_layernorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor mu,
77+
const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx,
78+
NVTETensor dgamma, NVTETensor dbeta, NVTETensor workspace,
8279
const int multiprocessorCount, const bool zero_centered_gamma,
8380
cudaStream_t stream);
8481

@@ -105,7 +102,7 @@ void nvte_layernorm_bwd(const NVTETensor dz, // BxSxhidden_size
105102
* calculated over the last dimension. Shape: [N].
106103
* \param[out] workspace Workspace tensor.
107104
* \param[in] multiprocessorCount Number of SMs in the device.
108-
* \param[in] zero_centered_gamma If zero_centered_gamma is enabled
105+
* \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
109106
* \param[in] stream CUDA stream used for the operation.
110107
*/
111108
void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float epsilon, NVTETensor z,
@@ -137,7 +134,7 @@ void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float ep
137134
* \param[out] dgamma Gradient for gamma tensor of shape [H].
138135
* \param[out] workspace Workspace tensor.
139136
* \param[in] multiprocessorCount Number of SMs in the device.
140-
* \param[in] zero_centered_gamma If zero_centered_gamma is enabled
137+
* \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$
141138
* \param[in] stream CUDA stream used for the operation.
142139
*/
143140
void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor rsigma,

0 commit comments

Comments
 (0)