QuickGELU activation from HuggingFace/Transformers (#475)
* Added QuickGELUActivation from HuggingFace/Transformers to common and pytorch

Signed-off-by: Alp Dener <[email protected]>

* Removing 'qgelu' from double-size activations list in LayerNormMLP.

Signed-off-by: Alp Dener <[email protected]>

* indent fix

Signed-off-by: Alp Dener <[email protected]>

---------

Signed-off-by: Alp Dener <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
denera and ptrendx authored Feb 17, 2024
1 parent d5c088d commit 0e116d5
Showing 11 changed files with 252 additions and 12 deletions.
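For context (this snippet is not part of the commit), QuickGELU approximates the exact GELU with a single sigmoid. A minimal PyTorch sketch of the formula this PR wires into Transformer Engine, assuming only a stock PyTorch install:

```python
import torch
import torch.nn.functional as F

def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    # QuickGELU as defined in HuggingFace/Transformers: x * sigmoid(1.702 * x)
    return x * torch.sigmoid(1.702 * x)

x = torch.linspace(-6.0, 6.0, steps=1001)
exact = F.gelu(x)                            # erf-based GELU
tanh_approx = F.gelu(x, approximate="tanh")  # tanh approximation
quick = quick_gelu(x)

print(f"max |quick - exact| = {(quick - exact).abs().max():.4f}")
print(f"max |tanh  - exact| = {(tanh_approx - exact).abs().max():.4f}")
```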
8 changes: 6 additions & 2 deletions tests/pytorch/test_numerics.py
@@ -61,7 +61,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq

all_boolean = [True, False]

-all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
+all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu"]

all_normalizations = ["LayerNorm", "RMSNorm"]

@@ -304,12 +304,16 @@ def forward(self, x, attention_mask=None):
output = output[0]
return output

class TorchQuickGELU(nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
return input * torch.sigmoid(1.702 * input)

_supported_act = {'geglu' : nn.GELU(approximate="tanh"),
'gelu' : nn.GELU(approximate="tanh"),
'reglu' : nn.ReLU(),
'relu' : nn.ReLU(),
-                  'swiglu' : nn.SiLU()}
+                  'swiglu' : nn.SiLU(),
+                  'qgelu'  : TorchQuickGELU()}


class TorchGLU(nn.Module):
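As an aside (not part of the diff), the TorchQuickGELU reference module above can be checked against the analytic QuickGELU derivative with plain PyTorch autograd; a minimal sketch, with an illustrative tolerance:

```python
import torch

def quick_gelu(x: torch.Tensor) -> torch.Tensor:
    # Same formula as TorchQuickGELU above.
    return x * torch.sigmoid(1.702 * x)

def quick_gelu_grad(x: torch.Tensor) -> torch.Tensor:
    # d/dx [x * sigmoid(1.702 x)] = sigmoid(1.702 x) + 1.702 * x * sigmoid'(1.702 x)
    s = torch.sigmoid(1.702 * x)
    return s + 1.702 * x * s * (1.0 - s)

x = torch.randn(1024, dtype=torch.float64, requires_grad=True)
(autograd_grad,) = torch.autograd.grad(quick_gelu(x).sum(), x)
assert torch.allclose(autograd_grad, quick_gelu_grad(x.detach()), atol=1e-12)
```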
73 changes: 73 additions & 0 deletions transformer_engine/common/activation/gelu.cu
@@ -127,6 +127,57 @@ void dgeglu(const Tensor &grad,
); // NOLINT(*)
}

void qgelu(const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "qgelu_input");
CheckOutputTensor(*output, "qgelu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
const size_t tot_elts = product(input.data.shape);

TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryKernelLauncher<nvec, Empty, qgelu<fp32, fp32> >(
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
tot_elts,
Empty(),
stream);
); // NOLINT(*)
); // NOLINT(*)
}

void dqgelu(const Tensor &grad,
const Tensor &input,
Tensor *output,
cudaStream_t stream) {
CheckInputTensor(input, "dqgelu_input");
CheckInputTensor(grad, "dqgelu_input_grad");
CheckOutputTensor(*output, "dqgelu_output");
NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
NVTE_CHECK(input.data.dtype == grad.data.dtype,
"Input and incoming gradient types must match.");
const size_t tot_elts = product(input.data.shape);

TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
constexpr int nvec = 32 / sizeof(IType);
VectorizedUnaryGradKernelLauncher<nvec, Empty, dqgelu<fp32, fp32>>(
reinterpret_cast<const IType*>(grad.data.dptr),
reinterpret_cast<const IType*>(input.data.dptr),
reinterpret_cast<OType*>(output->data.dptr),
reinterpret_cast<const fp32*>(output->scale.dptr),
reinterpret_cast<fp32*>(output->amax.dptr),
tot_elts,
{},
stream);
); // NOLINT(*)
); // NOLINT(*)
}

} // namespace transformer_engine

void nvte_gelu(const NVTETensor input,
@@ -172,3 +223,25 @@ void nvte_dgeglu(const NVTETensor grad,
reinterpret_cast<Tensor*>(output),
stream);
}

void nvte_qgelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_qgelu);
using namespace transformer_engine;
qgelu(*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}

void nvte_dqgelu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dqgelu);
using namespace transformer_engine;
dqgelu(*reinterpret_cast<const Tensor*>(grad),
*reinterpret_cast<const Tensor*>(input),
reinterpret_cast<Tensor*>(output),
stream);
}
36 changes: 29 additions & 7 deletions transformer_engine/common/include/transformer_engine/activation.h
@@ -47,8 +47,8 @@ void nvte_dgelu(const NVTETensor grad,
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_geglu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream);

/*! \brief Compute GeGLU gradient.
* \param[in] grad Incoming gradient of shape [N, H].
@@ -113,8 +113,8 @@ void nvte_dswiglu(const NVTETensor grad,
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_reglu(const NVTETensor input,
                NVTETensor output,
                cudaStream_t stream);

/*! \brief Compute ReGLU gradient.
* \param[in] grad Incoming gradient of shape [N, H].
@@ -123,9 +123,31 @@ void nvte_reglu(const NVTETensor input,
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dreglu(const NVTETensor grad,
                 const NVTETensor input,
                 NVTETensor output,
                 cudaStream_t stream);

/*! \brief Compute QuickGELU activation of the input.
*
* \param[in] input Input tensor for QuickGELU activation.
* \param[in,out] output Output tensor. Approximates GELU as input x sigmoid(1.702 x input).
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_qgelu(const NVTETensor input,
NVTETensor output,
cudaStream_t stream);

/*! \brief Compute QuickGELU activation gradient.
*
* \param[in] grad Incoming gradient.
* \param[in] input Input tensor for QuickGELU activation.
* \param[in,out] output Output tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dqgelu(const NVTETensor grad,
const NVTETensor input,
NVTETensor output,
cudaStream_t stream);

#ifdef __cplusplus
} // extern "C"
13 changes: 13 additions & 0 deletions transformer_engine/common/util/math.h
@@ -39,6 +39,19 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) {
return s * (1.f - s);
}

template <typename OType, typename IType>
__device__ inline OType qgelu(const IType val, const Empty& e) {
const float cval = val;
return cval * sigmoid<float, float>(1.702f * cval, e);
}

template <typename OType, typename IType>
__device__ inline OType dqgelu(const IType val, const Empty& e) {
const float cval = val;
  return 1.702f * cval * dsigmoid<float, float>(1.702f * cval, e) +
sigmoid<float, float>(1.702f * cval, e);
}

template <typename OType, typename IType>
__device__ inline OType swish(const IType val, const Empty& e) {
const float cval = val;
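For reference, writing $\sigma$ for the sigmoid, QuickGELU and its analytic derivative (by the product and chain rules) are

$$
f(x) = x\,\sigma(1.702\,x), \qquad
f'(x) = \sigma(1.702\,x) + 1.702\,x\,\sigma(1.702\,x)\bigl(1 - \sigma(1.702\,x)\bigr),
$$

which, in terms of the sigmoid and dsigmoid helpers defined in this header, is sigmoid(1.702f * x) + 1.702f * x * dsigmoid(1.702f * x).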
28 changes: 27 additions & 1 deletion transformer_engine/pytorch/cpp_extensions/activation.py
@@ -8,7 +8,7 @@
import transformer_engine_extensions as tex


__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu']
__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu']


def gelu(
@@ -140,3 +140,29 @@ def swiglu(
fp8_tensor,
otype,
)


def qgelu(
inp: torch.Tensor,
fp8_meta_tensor: tex.FP8TensorMeta,
fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
otype: tex.DType,
) -> torch.Tensor:
"""QuickGELU with FP8 output"""
empty_tensor = torch.Tensor()
if fp8_meta_tensor is not None:
scale = fp8_meta_tensor.scale
amax_history = fp8_meta_tensor.amax_history
scale_inv = fp8_meta_tensor.scale_inv
else:
scale = empty_tensor
amax_history = empty_tensor
scale_inv = empty_tensor
return torch.ops.tex_ts.qgelu_ts(
inp,
scale,
amax_history,
scale_inv,
fp8_tensor,
otype,
)
12 changes: 12 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -302,6 +302,13 @@ at::Tensor swiglu(at::Tensor input,
transformer_engine::DType otype
);

at::Tensor qgelu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
);

at::Tensor dgelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
@@ -327,6 +334,11 @@ at::Tensor dswiglu(at::Tensor grad,
transformer_engine::DType otype
);

at::Tensor dqgelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
);

/***************************************************************************************************
* LayerNorm
**************************************************************************************************/
52 changes: 52 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/activation.cu
@@ -265,3 +265,55 @@ at::Tensor dswiglu(at::Tensor grad,

return output;
}

at::Tensor qgelu(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
transformer_engine::DType otype
) {
using namespace transformer_engine;

size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;

auto output =
allocateTorchTensor(M,
N,
otype);

auto itype = GetTransformerEngineDType(input.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype,
amax.data_ptr(), scale.data_ptr(),
scale_inv.data_ptr());

nvte_qgelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());

return output;
}

at::Tensor dqgelu(at::Tensor grad,
at::Tensor input,
transformer_engine::DType otype
) {
using namespace transformer_engine;

size_t N = static_cast<size_t>(input.size(-1));
size_t M = input.numel() / N;

auto output =
allocateTorchTensor(M,
N,
otype);

auto itype = GetTransformerEngineDType(input.scalar_type());
auto gtype = GetTransformerEngineDType(grad.scalar_type());
auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype);
auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);

nvte_dqgelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());

return output;
}
2 changes: 2 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -72,11 +72,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("geglu", &geglu, "GeGLU with FP8 output");
m.def("reglu", &reglu, "ReGLU with FP8 output");
m.def("swiglu", &swiglu, "SwiGLU with FP8 output");
m.def("qgelu", &qgelu, "QuickGELU with FP8 output");
m.def("dgelu", &dgelu, "Backward of GeLU");
m.def("drelu", &drelu, "Backward of ReLU");
m.def("dgeglu", &dgeglu, "Backward of GeGLU");
m.def("dreglu", &dreglu, "Backward of ReGLU");
m.def("dswiglu", &dswiglu, "Backward of SwiGLU");
m.def("dqgelu", &dqgelu, "Backward of QuickGELU");
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");
35 changes: 35 additions & 0 deletions transformer_engine/pytorch/csrc/ts_fp8_op.cpp
@@ -247,6 +247,40 @@ at::Tensor swiglu_ts(at::Tensor input,
return output;
}

at::Tensor qgelu_ts(at::Tensor input,
at::Tensor scale,
at::Tensor amax,
at::Tensor scale_inv,
int64_t fp8_tensor,
int64_t otype) {
transformer_engine::DType otype_arg = reverse_map_dtype(otype);

at::Tensor s, a, s_inv;
if (scale.numel()) {
s = scale[fp8_tensor];
} else {
s = scale;
}

if (amax.numel()) {
a = amax[0][fp8_tensor];
} else {
a = amax;
}

if (scale_inv.numel()) {
s_inv = scale_inv[fp8_tensor];
} else {
s_inv = scale_inv;
}

at::Tensor output = qgelu(input,
s,
a,
s_inv,
otype_arg);
return output;
}

at::Tensor te_gemm_ts(at::Tensor A,
at::Tensor A_scale_inverse,
@@ -406,6 +440,7 @@ TORCH_LIBRARY(tex_ts, m) {
m.def("geglu_ts", &geglu_ts);
m.def("reglu_ts", &reglu_ts);
m.def("swiglu_ts", &swiglu_ts);
m.def("qgelu_ts", &qgelu_ts);
m.def("te_gemm_ts", &te_gemm_ts);
m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);
3 changes: 2 additions & 1 deletion transformer_engine/pytorch/module/layernorm_mlp.py
@@ -63,6 +63,7 @@ def _act_func(activation: str):
'geglu': (tex.geglu, tex.dgeglu),
'reglu': (tex.reglu, tex.dreglu),
'swiglu': (tex.swiglu, tex.dswiglu),
'qgelu': (tex.qgelu, tex.dqgelu)
}
if activation not in funcs:
raise NotImplementedError("Activation type " + activation + " is not supported!")
@@ -1078,7 +1079,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
type of normalization applied.
activation : str, default = 'gelu'
activation function used.
-Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'.
+Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu', 'qgelu'.
init_method : Callable, default = `None`
used for initializing FC1 weights in the following way: `init_method(weight)`.
When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
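As a usage illustration (not part of the diff), the new activation is selected by name on the module; a minimal sketch, assuming Transformer Engine is installed and a CUDA device is available (sizes are arbitrary):

```python
import torch
import transformer_engine.pytorch as te

# LayerNormMLP applies LayerNorm -> FC1 -> activation -> FC2; here the
# activation between FC1 and FC2 is the new QuickGELU ("qgelu").
mlp = te.LayerNormMLP(
    hidden_size=1024,
    ffn_hidden_size=4096,
    activation="qgelu",
)

x = torch.randn(128, 4, 1024, device="cuda")  # [seq, batch, hidden]
y = mlp(x)
y.sum().backward()
print(y.shape)  # torch.Size([128, 4, 1024])
```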
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/transformer.py
@@ -163,7 +163,7 @@ class TransformerLayer(torch.nn.Module):
if set to `False`, the transformer layer will not learn any additive biases.
activation : str, default = 'gelu'
Type of activation used in MLP block.
-Options are: 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'.
+Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu' and 'qgelu'.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
