diff --git a/torchrec/modules/crossnet.py b/torchrec/modules/crossnet.py index 43771be5d..5188455d1 100644 --- a/torchrec/modules/crossnet.py +++ b/torchrec/modules/crossnet.py @@ -196,7 +196,7 @@ class VectorCrossNet(torch.nn.Module): On each layer l, the tensor is transformed into - .. math:: x_{l+1} = x_0 * (W_l . x_l + b_l) + x_l + .. math:: x_{l+1} = x_0 * (W_l . x_l) + b_l + x_l where :math:`W_l` is either a vector, :math:`*` means element-wise multiplication; :math:`.` means dot operations.