Commit 643e867

Added QuickGELUActivation from HuggingFace/Transformers to common and pytorch
Signed-off-by: Alp Dener <[email protected]>
1 parent 2574a1c commit 643e867

File tree

11 files changed: +247 -6 lines changed

tests/pytorch/test_numerics.py

Lines changed: 6 additions & 2 deletions
@@ -57,7 +57,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq
 
 all_boolean = [True, False]
 
-all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"]
+all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu"]
 
 all_normalizations = ["LayerNorm", "RMSNorm"]
 
@@ -382,12 +382,16 @@ def forward(self, x, attention_mask=None):
         output = output[0]
         return output
 
+class TorchQuickGELU(nn.Module):
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        return input * torch.sigmoid(1.702 * input)
 
 _supported_act = {'geglu'  : nn.GELU(approximate="tanh"),
                   'gelu'   : nn.GELU(approximate="tanh"),
                   'reglu'  : nn.ReLU(),
                   'relu'   : nn.ReLU(),
-                  'swiglu' : nn.SiLU()}
+                  'swiglu' : nn.SiLU(),
+                  'qgelu'  : TorchQuickGELU()}
 
 
 class TorchGLU(nn.Module):
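
The TorchQuickGELU module above is the sigmoid-based approximation popularized by HuggingFace Transformers. As an illustrative, standalone sketch (not part of the test file), the snippet below compares it against PyTorch's exact and tanh-approximated GELU; the printed error magnitudes are indicative only.

import torch

x = torch.linspace(-6.0, 6.0, steps=1001)

quick = x * torch.sigmoid(1.702 * x)                           # QuickGELU, as in TorchQuickGELU
exact = torch.nn.functional.gelu(x)                            # erf-based GELU
tanh_approx = torch.nn.functional.gelu(x, approximate="tanh")  # tanh-approximated GELU

# The sigmoid approximation tracks GELU to within a few 1e-2 over this range.
print(float((quick - exact).abs().max()))
print(float((quick - tanh_approx).abs().max()))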

transformer_engine/common/activation/gelu.cu

Lines changed: 73 additions & 0 deletions
@@ -127,6 +127,57 @@ void dgeglu(const Tensor &grad,
   );  // NOLINT(*)
 }
 
+void qgelu(const Tensor &input,
+           Tensor *output,
+           cudaStream_t stream) {
+  CheckInputTensor(input, "qgelu_input");
+  CheckOutputTensor(*output, "qgelu_output");
+  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
+  const size_t tot_elts = product(input.data.shape);
+
+  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
+    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
+      constexpr int nvec = 32 / sizeof(IType);
+      VectorizedUnaryKernelLauncher<nvec, Empty, qgelu<fp32, fp32> >(
+        reinterpret_cast<const IType*>(input.data.dptr),
+        reinterpret_cast<OType*>(output->data.dptr),
+        reinterpret_cast<const fp32*>(output->scale.dptr),
+        reinterpret_cast<fp32*>(output->amax.dptr),
+        tot_elts,
+        Empty(),
+        stream);
+    );  // NOLINT(*)
+  );  // NOLINT(*)
+}
+
+void dqgelu(const Tensor &grad,
+            const Tensor &input,
+            Tensor *output,
+            cudaStream_t stream) {
+  CheckInputTensor(input, "dqgelu_input");
+  CheckInputTensor(grad, "dqgelu_input_grad");
+  CheckOutputTensor(*output, "dqgelu_output");
+  NVTE_CHECK(input.data.shape == output->data.shape, "Input and output shapes must match.");
+  NVTE_CHECK(input.data.dtype == grad.data.dtype,
+             "Input and incoming gradient types must match.");
+  const size_t tot_elts = product(input.data.shape);
+
+  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, IType,
+    TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(output->data.dtype, OType,
+      constexpr int nvec = 32 / sizeof(IType);
+      VectorizedUnaryGradKernelLauncher<nvec, Empty, dqgelu<fp32, fp32>>(
+        reinterpret_cast<const IType*>(grad.data.dptr),
+        reinterpret_cast<const IType*>(input.data.dptr),
+        reinterpret_cast<OType*>(output->data.dptr),
+        reinterpret_cast<const fp32*>(output->scale.dptr),
+        reinterpret_cast<fp32*>(output->amax.dptr),
+        tot_elts,
+        {},
+        stream);
+    );  // NOLINT(*)
+  );  // NOLINT(*)
+}
+
 }  // namespace transformer_engine
 
 void nvte_gelu(const NVTETensor input,
@@ -172,3 +223,25 @@ void nvte_dgeglu(const NVTETensor grad,
          reinterpret_cast<Tensor*>(output),
          stream);
 }
+
+void nvte_qgelu(const NVTETensor input,
+                NVTETensor output,
+                cudaStream_t stream) {
+  NVTE_API_CALL(nvte_qgelu);
+  using namespace transformer_engine;
+  qgelu(*reinterpret_cast<const Tensor*>(input),
+        reinterpret_cast<Tensor*>(output),
+        stream);
+}
+
+void nvte_dqgelu(const NVTETensor grad,
+                 const NVTETensor input,
+                 NVTETensor output,
+                 cudaStream_t stream) {
+  NVTE_API_CALL(nvte_dqgelu);
+  using namespace transformer_engine;
+  dqgelu(*reinterpret_cast<const Tensor*>(grad),
+         *reinterpret_cast<const Tensor*>(input),
+         reinterpret_cast<Tensor*>(output),
+         stream);
+}
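
The forward kernel applies qgelu<fp32, fp32> element-wise through the vectorized launcher and, when an FP8 output is requested, also records an amax and applies the output scale. A minimal Python emulation of that per-element behaviour is sketched below; the FP8 convention (amax taken on the unscaled activation, scale applied before the cast) is an assumption based on the surrounding Transformer Engine kernels, not something this diff spells out.

import torch

def qgelu_fp8_emulation(x: torch.Tensor, scale: float = 1.0):
    # Element-wise QuickGELU computed in fp32, as qgelu<fp32, fp32> does.
    act = x.float() * torch.sigmoid(1.702 * x.float())
    # Assumed FP8 bookkeeping: amax of the unscaled activation, scaled output.
    amax = act.abs().max()
    out = act * scale  # the real kernel would then cast this to the FP8 output type
    return out, amax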

transformer_engine/common/include/transformer_engine/activation.h

Lines changed: 22 additions & 0 deletions
@@ -127,6 +127,28 @@ void nvte_dreglu(const NVTETensor grad,
                  NVTETensor output,
                  cudaStream_t stream);
 
+/*! \brief Compute QuickGELU activation of the input.
+ *
+ *  \param[in]     input     Input tensor for QuickGELU activation.
+ *  \param[in,out] output    Output tensor. Approximates GELU as input x sigmoid(1.702 x input).
+ *  \param[in]     stream    CUDA stream used for the operation.
+ */
+void nvte_qgelu(const NVTETensor input,
+                NVTETensor output,
+                cudaStream_t stream);
+
+/*! \brief Compute QuickGELU activation gradient.
+ *
+ *  \param[in]     grad      Incoming gradient.
+ *  \param[in]     input     Input tensor for QuickGELU activation.
+ *  \param[in,out] output    Output tensor.
+ *  \param[in]     stream    CUDA stream used for the operation.
+ */
+void nvte_dqgelu(const NVTETensor grad,
+                 const NVTETensor input,
+                 NVTETensor output,
+                 cudaStream_t stream);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
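
In formula form, the activation documented above and its relation to GELU are

\mathrm{QuickGELU}(x) = x\,\sigma(1.702\,x), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}, \qquad \mathrm{GELU}(x) = x\,\Phi(x),

where \Phi is the standard normal CDF; the constant 1.702 makes the scaled sigmoid a close fit to \Phi.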

transformer_engine/common/util/math.h

Lines changed: 13 additions & 0 deletions
@@ -39,6 +39,19 @@ __device__ inline OType dsigmoid(const IType val, const Empty& e) {
   return s * (1.f - s);
 }
 
+template <typename OType, typename IType>
+__device__ inline OType qgelu(const IType val, const Empty& e) {
+  const float cval = val;
+  return cval * sigmoid<float, float>(1.702f * cval, e);
+}
+
+template <typename OType, typename IType>
+__device__ inline OType dqgelu(const IType val, const Empty& e) {
+  const float cval = val;
+  return cval * dsigmoid<float, float>(1.702f * cval, e) +
+         sigmoid<float, float>(1.702f * cval, e);
+}
+
 template <typename OType, typename IType>
 __device__ inline OType swish(const IType val, const Empty& e) {
   const float cval = val;
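
For reference, the analytic gradient that dqgelu targets follows from the product and chain rules:

\frac{d}{dx}\bigl[x\,\sigma(1.702x)\bigr] = \sigma(1.702x) + 1.702\,x\,\sigma(1.702x)\bigl(1 - \sigma(1.702x)\bigr),

where \sigma(z)(1 - \sigma(z)) is exactly what dsigmoid returns, and the chain rule contributes the extra factor of 1.702 in the second term.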

transformer_engine/pytorch/cpp_extensions/activation.py

Lines changed: 27 additions & 1 deletion
@@ -8,7 +8,7 @@
 import transformer_engine_extensions as tex
 
 
-__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu']
+__all__ = ['gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu']
 
 
 def gelu(
@@ -140,3 +140,29 @@ def swiglu(
         fp8_tensor,
         otype,
     )
+
+
+def qgelu(
+    inp: torch.Tensor,
+    fp8_meta_tensor: tex.FP8TensorMeta,
+    fp8_tensor: Union[tex.FP8FwdTensors, tex.FP8BwdTensors],
+    otype: tex.DType,
+) -> torch.Tensor:
+    """QuickGELU with FP8 output"""
+    empty_tensor = torch.Tensor()
+    if fp8_meta_tensor is not None:
+        scale = fp8_meta_tensor.scale
+        amax_history = fp8_meta_tensor.amax_history
+        scale_inv = fp8_meta_tensor.scale_inv
+    else:
+        scale = empty_tensor
+        amax_history = empty_tensor
+        scale_inv = empty_tensor
+    return torch.ops.tex_ts.qgelu_ts(
+        inp,
+        scale,
+        amax_history,
+        scale_inv,
+        fp8_tensor,
+        otype,
+    )
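
A hedged usage sketch of the new Python wrapper in its non-FP8 path: passing None for fp8_meta_tensor forwards empty scale/amax tensors, mirroring the other helpers in this file. The import path, tensor slot, and dtypes below are illustrative assumptions, not prescribed by this commit.

import torch
import transformer_engine_extensions as tex
from transformer_engine.pytorch.cpp_extensions import qgelu  # assumes the package re-exports activation.py

x = torch.randn(16, 1024, dtype=torch.bfloat16, device="cuda")

# Non-FP8 path: no FP8 meta tensor, so scale/amax/scale_inv stay empty and the
# kernel writes plain bf16 output. The tensor slot is required by the signature
# only; GEMM2_INPUT is a placeholder here.
y = qgelu(x, None, tex.FP8FwdTensors.GEMM2_INPUT, tex.DType.kBFloat16)
assert y.shape == x.shape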

transformer_engine/pytorch/csrc/extensions.h

Lines changed: 12 additions & 0 deletions
@@ -295,6 +295,13 @@ at::Tensor swiglu(at::Tensor input,
                   transformer_engine::DType otype
 );
 
+at::Tensor qgelu(at::Tensor input,
+                 at::Tensor scale,
+                 at::Tensor amax,
+                 at::Tensor scale_inv,
+                 transformer_engine::DType otype
+);
+
 at::Tensor dgelu(at::Tensor grad,
                  at::Tensor input,
                  transformer_engine::DType otype
@@ -320,6 +327,11 @@ at::Tensor dswiglu(at::Tensor grad,
                    transformer_engine::DType otype
 );
 
+at::Tensor dqgelu(at::Tensor grad,
+                  at::Tensor input,
+                  transformer_engine::DType otype
+);
+
 /***************************************************************************************************
  * LayerNorm
  **************************************************************************************************/

transformer_engine/pytorch/csrc/extensions/activation.cu

Lines changed: 52 additions & 0 deletions
@@ -265,3 +265,55 @@ at::Tensor dswiglu(at::Tensor grad,
 
   return output;
 }
+
+at::Tensor qgelu(at::Tensor input,
+                 at::Tensor scale,
+                 at::Tensor amax,
+                 at::Tensor scale_inv,
+                 transformer_engine::DType otype
+) {
+  using namespace transformer_engine;
+
+  size_t N = static_cast<size_t>(input.size(-1));
+  size_t M = input.numel() / N;
+
+  auto output =
+      allocateTorchTensor(M,
+                          N,
+                          otype);
+
+  auto itype = GetTransformerEngineDType(input.scalar_type());
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
+  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype,
+                                               amax.data_ptr(), scale.data_ptr(),
+                                               scale_inv.data_ptr());
+
+  nvte_qgelu(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
+
+  return output;
+}
+
+at::Tensor dqgelu(at::Tensor grad,
+                  at::Tensor input,
+                  transformer_engine::DType otype
+) {
+  using namespace transformer_engine;
+
+  size_t N = static_cast<size_t>(input.size(-1));
+  size_t M = input.numel() / N;
+
+  auto output =
+      allocateTorchTensor(M,
+                          N,
+                          otype);
+
+  auto itype = GetTransformerEngineDType(input.scalar_type());
+  auto gtype = GetTransformerEngineDType(grad.scalar_type());
+  auto input_cu = makeTransformerEngineTensor(input.data_ptr(), {M, N}, itype);
+  auto grad_cu = makeTransformerEngineTensor(grad.data_ptr(), {M, N}, gtype);
+  auto output_cu = makeTransformerEngineTensor(output.data_ptr(), {M, N}, otype);
+
+  nvte_dqgelu(grad_cu.data(), input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream());
+
+  return output;
+}

transformer_engine/pytorch/csrc/extensions/pybind.cpp

Lines changed: 2 additions & 0 deletions
@@ -66,11 +66,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("geglu", &geglu, "GeGLU with FP8 output");
   m.def("reglu", &reglu, "ReGLU with FP8 output");
   m.def("swiglu", &swiglu, "SwiGLU with FP8 output");
+  m.def("qgelu", &qgelu, "QuickGELU with FP8 output");
   m.def("dgelu", &dgelu, "Backward of GeLU");
   m.def("drelu", &drelu, "Backward of ReLU");
   m.def("dgeglu", &dgeglu, "Backward of GeGLU");
   m.def("dreglu", &dreglu, "Backward of ReGLU");
   m.def("dswiglu", &dswiglu, "Backward of SwiGLU");
+  m.def("dqgelu", &dqgelu, "Backward of QuickGELU");
   m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
   m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
   m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend");

transformer_engine/pytorch/csrc/ts_fp8_op.cpp

Lines changed: 36 additions & 0 deletions
@@ -222,6 +222,41 @@ at::Tensor swiglu_ts(at::Tensor input,
   return output;
 }
 
+at::Tensor qgelu_ts(at::Tensor input,
+                    at::Tensor scale,
+                    at::Tensor amax,
+                    at::Tensor scale_inv,
+                    int64_t fp8_tensor,
+                    int64_t otype) {
+  transformer_engine::DType otype_arg = reverse_map_dtype(otype);
+
+  at::Tensor s, a, s_inv;
+  if (scale.numel()) {
+    s = scale[fp8_tensor];
+  } else {
+    s = scale;
+  }
+
+  if (amax.numel()) {
+    a = amax[0][fp8_tensor];
+  } else {
+    a = amax;
+  }
+
+  if (scale_inv.numel()) {
+    s_inv = scale_inv[fp8_tensor];
+  } else {
+    s_inv = scale_inv;
+  }
+
+  at::Tensor output = qgelu(input,
+                            s,
+                            a,
+                            s_inv,
+                            otype_arg);
+  return output;
+}
+
 at::Tensor te_gemm_ts(at::Tensor A,
                       at::Tensor A_scale_inverse,
                       int64_t A_fp8_tensor,
@@ -374,6 +409,7 @@ TORCH_LIBRARY(tex_ts, m) {
   m.def("geglu_ts", &geglu_ts);
   m.def("reglu_ts", &reglu_ts);
   m.def("swiglu_ts", &swiglu_ts);
+  m.def("qgelu_ts", &qgelu_ts);
   m.def("te_gemm_ts", &te_gemm_ts);
   m.def("layernorm_fwd_fp8_inf_ts", &layernorm_fwd_fp8_inf_ts);
   m.def("layernorm_fwd_inf_ts", &layernorm_fwd_inf_ts);

transformer_engine/pytorch/module/layernorm_mlp.py

Lines changed: 3 additions & 2 deletions
@@ -60,6 +60,7 @@ def _act_func(activation: str):
         'geglu': (tex.geglu, tex.dgeglu),
         'reglu': (tex.reglu, tex.dreglu),
         'swiglu': (tex.swiglu, tex.dswiglu),
+        'qgelu': (tex.qgelu, tex.dqgelu)
     }
     if activation not in funcs:
         raise NotImplementedError("Activation type " + activation + " is not supported!")
@@ -930,7 +931,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                    type of normalization applied.
     activation : str, default = 'gelu'
                 activation function used.
-                Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu'.
+                Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu', 'qgelu'.
     init_method : Callable, default = `None`
                  used for initializing FC1 weights in the following way: `init_method(weight)`.
                  When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`.
@@ -1097,7 +1098,7 @@ def __init__(
             self.layer_norm_bias = None
             self.reset_layer_norm_parameters()
 
-        if self.activation in ['reglu', 'geglu', 'swiglu']:
+        if self.activation in ['reglu', 'geglu', 'swiglu', 'qgelu']:
             fc1_output_features = 2 * self.size_per_partition
         else:
             fc1_output_features = self.size_per_partition
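
A hedged configuration sketch showing where the new option plugs in (construction only, with illustrative sizes; a CUDA device is assumed since Transformer Engine modules allocate their parameters on the GPU):

import transformer_engine.pytorch as te

# 'qgelu' is now accepted alongside 'gelu', 'relu', 'reglu', 'geglu' and 'swiglu'.
mlp = te.LayerNormMLP(hidden_size=1024, ffn_hidden_size=4096, activation="qgelu")
print(mlp.activation)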
