[Paddle] Add RMSNorm, RoPE and SwiGLU (#599)
* use separate qkv

Signed-off-by: jaywan <[email protected]>

add support for GQA

Signed-off-by: jaywan <[email protected]>

minor changes

Signed-off-by: Shijie Wang <[email protected]>

change rtol

Signed-off-by: Shijie Wang <[email protected]>

fix reshape issue

Signed-off-by: Shijie Wang <[email protected]>

add rmsnorm and rotary position embedding

Signed-off-by: Shijie Wang <[email protected]>

update rmsnorm

Signed-off-by: Shijie Wang <[email protected]>

refactor layernorm and rmsnorm

Signed-off-by: Shijie Wang <[email protected]>

support swiglu

Signed-off-by: Shijie Wang <[email protected]>

add fused rope

Signed-off-by: Shijie Wang <[email protected]>

minor changes

Signed-off-by: Shijie Wang <[email protected]>

add rope api to __init__

Signed-off-by: Shijie Wang <[email protected]>

minor changes

Signed-off-by: Shijie Wang <[email protected]>

fix fp8 dtype issue

Signed-off-by: Shijie Wang <[email protected]>

* simplify ut cases

Signed-off-by: jaywan <[email protected]>

* Update transformer_engine/paddle/layer/attention.py

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Shijie <[email protected]>

* fix name issue

Signed-off-by: Shijie Wang <[email protected]>

---------

Signed-off-by: Shijie Wang <[email protected]>
Signed-off-by: jaywan <[email protected]>
Signed-off-by: Shijie <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Wong4j and timmoon10 authored Feb 21, 2024
1 parent 2187a8f commit 7172509
Showing 12 changed files with 881 additions and 237 deletions.
153 changes: 101 additions & 52 deletions tests/paddle/test_layers.py

Large diffs are not rendered by default.

65 changes: 56 additions & 9 deletions tests/paddle/test_operators.py
@@ -5,12 +5,6 @@

import struct

from utils import (
assert_allclose,
create_fp8_meta,
get_fused_attention_backend,
is_fused_attention_supported,
)
import numpy as np
import paddle
import paddle.nn.functional as F
@@ -34,6 +28,10 @@
cast_transpose_bgrad,
te_gelu,
gelu_fp8,
swiglu,
swiglu_fp8,
swiglu_pd,
dswiglu,
dgelu_cast_transpose_bgrad_fp8,
layernorm_fwd_fp8,
layernorm_fwd,
@@ -62,9 +60,9 @@
(16384, 1024, 1024)]
is_fp8_supported, reason = is_fp8_available()

SELF_ATTN_CASES = [(32, 512, 16, 64), (32, 128, 16, 64)]
CROSS_ATTN_CASES = [(32, 128, 512, 16, 64)]
FLASH_ATTN_CASES = [(4, 1024, 16, 64), (2, 2048, 16, 128)]
SELF_ATTN_CASES = [(2, 512, 12, 64)]
CROSS_ATTN_CASES = [(2, 128, 512, 12, 64)]
FLASH_ATTN_CASES = [(2, 1024, 16, 64), (2, 2048, 16, 128)]
ATTN_DTYPES = [tex.DType.kFloat16, tex.DType.kBFloat16]


@@ -296,6 +294,55 @@ def test_gelu_bwd_fp8(fp8_dtype):
assert_allclose(x_grad_t, x.grad.T, rtol=0.1, atol=0.01)
assert_allclose(dbias, x.grad.sum(axis=0), rtol=0.1, atol=0.01)

@staticmethod
def test_swiglu_bf16():
"""
Test BF16 SwiGLU Forward
"""
a = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1
swiglu_out = swiglu(a, otype=tex.DType.kBFloat16)
swiglu_ref = swiglu_pd(a)

assert_allclose(swiglu_out, swiglu_ref, rtol=1e-2)

@staticmethod
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('fp8_dtype', [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2])
def test_swiglu_fp8(fp8_dtype):
"""
Test FP8 SwiGLU Forward
"""
a = paddle.rand(shape=(16, 32), dtype='float32') * 2 - 1
fp8_meta = create_fp8_meta()

swiglu_out_fp8 = swiglu_fp8(a, fp8_meta, FP8FwdTensors.GEMM1_INPUT, otype=fp8_dtype)

swiglu_out = cast_from_fp8(swiglu_out_fp8,
fp8_meta,
FP8FwdTensors.GEMM1_INPUT,
itype=fp8_dtype,
otype=tex.DType.kFloat32)

swiglu_ref = swiglu_pd(a)

assert_allclose(swiglu_out, swiglu_ref, rtol=0.1, atol=0.01)

@staticmethod
def test_swiglu_bwd():
"""
Test SwiGLU Backward
"""
# y = SwiGLU(x), calculate ref
x = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1
x.stop_gradient = False
y = swiglu_pd(x)
y_grad = paddle.rand(shape=(16, 16), dtype='bfloat16') * 2 - 1
paddle.autograd.backward([y], [y_grad], True)
# calculate fused dswiglu
x_grad = dswiglu(y_grad, x, otype=tex.DType.kBFloat16)

assert_allclose(x_grad, x.grad, rtol=0.1, atol=0.01)


class TestGemm:
"""
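For reference, the new SwiGLU primitives exercised by the tests above can be checked against Paddle autograd in isolation. The sketch below mirrors `test_swiglu_bf16` and `test_swiglu_bwd`; the import path `transformer_engine.paddle.cpp_extensions` is an assumption matching the rest of this test file, not something shown in the diff.

```python
# Minimal sketch (not part of this commit) comparing the fused SwiGLU ops against
# the native Paddle reference. Import paths are assumed, not taken from the diff.
import paddle
import transformer_engine_paddle as tex
from transformer_engine.paddle.cpp_extensions import swiglu, swiglu_pd, dswiglu

x = paddle.rand(shape=(16, 32), dtype='bfloat16') * 2 - 1
x.stop_gradient = False

# Forward: fused kernel vs. silu(gate) * up on the two halves of the last axis.
y_fused = swiglu(x, otype=tex.DType.kBFloat16)
y_ref = swiglu_pd(x)

# Backward: fused dSwiGLU vs. Paddle autograd through the reference path.
dy = paddle.rand(shape=(16, 16), dtype='bfloat16') * 2 - 1
paddle.autograd.backward([y_ref], [dy], True)
dx_fused = dswiglu(dy, x, otype=tex.DType.kBFloat16)

print(paddle.allclose(y_fused.astype('float32'), y_ref.astype('float32'), rtol=1e-2))
print(paddle.allclose(dx_fused.astype('float32'), x.grad.astype('float32'), rtol=1e-1, atol=1e-2))
```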
13 changes: 11 additions & 2 deletions transformer_engine/paddle/__init__.py
@@ -4,6 +4,15 @@
"""Transformer Engine bindings for Paddle"""

from .fp8 import fp8_autocast
from .layer import (Linear, LayerNorm, LayerNormLinear, LayerNormMLP, FusedScaleMaskSoftmax,
DotProductAttention, MultiHeadAttention, TransformerLayer)
from .layer import (
Linear,
LayerNorm,
LayerNormLinear,
LayerNormMLP,
FusedScaleMaskSoftmax,
DotProductAttention,
MultiHeadAttention,
TransformerLayer,
RotaryPositionEmbedding,
)
from .recompute import recompute
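With this change, `RotaryPositionEmbedding` is importable from the top-level `transformer_engine.paddle` package alongside the existing layers. Below is a hypothetical usage sketch only: the constructor argument and the call returning a frequency tensor are assumptions based on the analogous PyTorch-side class and may differ from the actual Paddle API.

```python
# Hypothetical sketch: argument meanings below (rotary dim, max sequence length)
# are assumptions, not taken from this commit.
import transformer_engine.paddle as te

rope = te.RotaryPositionEmbedding(64)   # assumed: per-head rotary dimension
freqs = rope(1024)                      # assumed: frequencies for a max sequence length of 1024
# The resulting tensor would then be handed to the attention layer updated in this
# commit (transformer_engine/paddle/layer/attention.py) as its rotary embedding input.
```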
60 changes: 57 additions & 3 deletions transformer_engine/paddle/cpp_extensions.py
@@ -6,6 +6,7 @@
import math
from typing import Optional, Tuple, Union
import paddle
import paddle.nn.functional as F
import transformer_engine_paddle as tex
from .constants import TE_DType, FusedAttnBackend, FP8FwdTensors, FP8BwdTensors
from .fp8 import FP8TensorMeta
@@ -328,6 +329,56 @@ def gelu_fp8(
return out


def swiglu(
inp: paddle.Tensor,
otype: tex.DType,
) -> paddle.Tensor:
"""Non FP8 SWIGLU"""
return tex.te_swiglu(
inp,
int(otype),
)


def swiglu_pd(inp: paddle.Tensor,) -> paddle.Tensor:
"""Native SWIGLU"""
gate_out, up_out = paddle.chunk(inp, chunks=2, axis=-1)
out = F.silu(gate_out) * up_out
return out


def swiglu_fp8(
inp: paddle.Tensor,
fp8_meta_tensor: FP8TensorMeta,
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
) -> paddle.Tensor:
"""SWIGLU + FP8 cast"""
out, _, _ = tex.te_swiglu_fp8(
inp,
fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv,
fp8_tensor.value,
int(otype),
)

return out


def dswiglu(
grad_output: paddle.Tensor,
swiglu_input: paddle.Tensor,
otype: tex.DType,
) -> paddle.Tensor:
"""dSWIGLU"""
return tex.te_dswiglu(
grad_output,
swiglu_input,
int(otype),
)


def dgelu_cast_transpose_bgrad_fp8(
grad_output: paddle.Tensor,
gelu_input: paddle.Tensor,
@@ -404,9 +455,10 @@ def rmsnorm_fwd(
eps: float,
otype: tex.DType,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 RMSNorm forward"""
return tex.te_rmsnorm_fwd(inp, weight, eps, int(otype), sm_margin)
return tex.te_rmsnorm_fwd(inp, weight, eps, int(otype), sm_margin, zero_centered_gamma)


def rmsnorm_fwd_fp8(
@@ -417,12 +469,13 @@ def rmsnorm_fwd_fp8(
fp8_tensor: Union[FP8FwdTensors, FP8BwdTensors],
otype: tex.DType,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""RMSNorm with FP8 output"""
out, rsigma, _, _ = tex.te_rmsnorm_fwd_fp8(inp, weight, fp8_meta_tensor.scale,
fp8_meta_tensor.amax_history,
fp8_meta_tensor.scale_inv, eps, fp8_tensor.value,
int(otype), sm_margin)
int(otype), sm_margin, zero_centered_gamma)
return out, rsigma


@@ -432,9 +485,10 @@ def rmsnorm_bwd(
rsigma: paddle.Tensor,
gamma: paddle.Tensor,
sm_margin: int = 0,
zero_centered_gamma: bool = False,
) -> Tuple[paddle.Tensor, paddle.Tensor, paddle.Tensor]:
"""Non-FP8 RMSNorm backward"""
return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin)
return tex.te_rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma)


def mask_to_cu_seqlens(
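The RMSNorm wrappers above now thread a `zero_centered_gamma` flag down to the custom ops; the C++ side currently rejects `True` (see the `NVTE_CHECK` guards in custom_ops.cu below), so the flag is plumbing for future support. A minimal sketch of the extended forward/backward wrappers follows; the import paths are assumptions.

```python
# Minimal sketch (not part of this commit). Only zero_centered_gamma=False is usable,
# since the te_rmsnorm_* ops currently NVTE_CHECK against True.
import paddle
import transformer_engine_paddle as tex
from transformer_engine.paddle.cpp_extensions import rmsnorm_fwd, rmsnorm_bwd

x = paddle.rand(shape=(16, 32), dtype='float32')
gamma = paddle.ones(shape=(32,), dtype='float32')

# Forward: two tensors, matching the te_rmsnorm_fwd outputs registered below
# (normalized output and per-row reciprocal RMS).
out, rsigma = rmsnorm_fwd(x, gamma, eps=1e-5, otype=tex.DType.kFloat32,
                          zero_centered_gamma=False)

# Backward: gradients w.r.t. the input and the gamma weight.
dz = paddle.ones_like(out)
dx, dgamma = rmsnorm_bwd(dz, x, rsigma, gamma, zero_centered_gamma=False)
```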
101 changes: 94 additions & 7 deletions transformer_engine/paddle/csrc/custom_ops.cu
@@ -218,6 +218,66 @@ std::vector<paddle::Tensor> te_gelu(const paddle::Tensor &input, int64_t otype)
return {output};
}

std::vector<paddle::Tensor> te_swiglu(const paddle::Tensor &input, int64_t otype) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");

size_t M = shape[0];
size_t N = shape[1];

auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2},
Nvte2PaddleDType(Int2NvteDType(otype)), input.place());

auto input_cu = MakeNvteTensor(input);
auto output_cu = MakeNvteTensor(output.data(), GetShapeArray(output), Int2NvteDType(otype));

nvte_swiglu(input_cu.data(), output_cu.data(), input.stream());

return {output};
}

std::vector<paddle::Tensor> te_swiglu_fp8(const paddle::Tensor &input, const paddle::Tensor &scale,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
int64_t index, int64_t otype) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");

size_t M = shape[0];
size_t N = shape[1];

auto output = paddle::empty({input.shape()[0], input.shape()[1] / 2},
Nvte2PaddleDType(Int2NvteDType(otype)), input.place());

auto input_cu = MakeNvteTensor(input);
auto output_cu = MakeNvteTensor(
output.data(), GetShapeArray(output), Int2NvteDType(otype), GetDataPtr<float>(amax, index),
const_cast<void *>(GetDataPtr<float>(scale, index)), GetDataPtr<float>(scale_inv, index));

nvte_swiglu(input_cu.data(), output_cu.data(), input.stream());

return {output};
}

std::vector<paddle::Tensor> te_dswiglu(const paddle::Tensor &grad, const paddle::Tensor &input,
int64_t otype) {
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the input to have 2 dimensions.");

size_t M = shape[0];
size_t N = shape[1];

auto output = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());

auto input_cu = MakeNvteTensor(input.data(), {M, N}, Paddle2NvteDType(input.dtype()));
auto grad_cu = MakeNvteTensor(grad.data(), {M, N / 2}, Paddle2NvteDType(grad.dtype()));
auto output_cu = MakeNvteTensor(output.data(), {M, N}, Paddle2NvteDType(output.dtype()));

nvte_dswiglu(grad_cu.data(), input_cu.data(), output_cu.data(), input.stream());

return {output};
}

std::vector<paddle::Tensor> te_cast_transpose_bgrad_dgelu(const paddle::Tensor &grad_output,
const paddle::Tensor &gelu_input,
const paddle::Tensor &scale,
@@ -406,7 +466,9 @@ std::vector<paddle::Tensor> te_layernorm_bwd(const paddle::Tensor &dz, const pad

std::vector<paddle::Tensor> te_rmsnorm_fwd(const paddle::Tensor &input,
const paddle::Tensor &weight, float eps, int64_t otype,
int64_t sm_margin) {
int64_t sm_margin, bool zero_centered_gamma) {
NVTE_CHECK(zero_centered_gamma == false,
"zero_centered_gamma is not supported yet for RMSNorm.");
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions.");

@@ -448,14 +510,16 @@ std::vector<paddle::Tensor> te_rmsnorm_fwd_fp8(const paddle::Tensor &input,
paddle::Tensor &amax, // NOLINT
paddle::Tensor &scale_inv, // NOLINT
float eps, int64_t index, int64_t otype,
int64_t sm_margin) {
int64_t sm_margin, bool zero_centered_gamma) {
NVTE_CHECK(zero_centered_gamma == false,
"zero_centered_gamma is not supported yet for RMSNorm.");
auto shape = GetShapeArray(input);
NVTE_CHECK(shape.size() == 2, "Expect the grad_output to have 2 dimensions.");

size_t N = shape[0];
size_t H = shape[1];

auto ln_out = paddle::empty_like(input, input.dtype(), input.place());
auto ln_out = paddle::empty_like(input, Nvte2PaddleDType(Int2NvteDType(otype)), input.place());
auto rsigma =
paddle::empty({static_cast<int64_t>(N)}, paddle::DataType::FLOAT32, input.place());
auto input_cu = MakeNvteTensor(input);
@@ -487,7 +551,10 @@ std::vector<paddle::Tensor> te_rmsnorm_fwd_fp8(

std::vector<paddle::Tensor> te_rmsnorm_bwd(const paddle::Tensor &dz, const paddle::Tensor &x,
const paddle::Tensor &rsigma,
const paddle::Tensor &gamma, int64_t sm_margin) {
const paddle::Tensor &gamma, int64_t sm_margin,
bool zero_centered_gamma) {
NVTE_CHECK(zero_centered_gamma == false,
"zero_centered_gamma is not supported yet for RMSNorm.");
auto dx = paddle::empty_like(x, x.dtype(), x.place());
auto dgamma = paddle::empty_like(gamma, gamma.dtype(), gamma.place());

@@ -1374,6 +1441,25 @@ PD_BUILD_OP(te_gelu)
.Attrs({"otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_gelu));

PD_BUILD_OP(te_swiglu)
.Inputs({"Input"})
.Outputs({"Output"})
.Attrs({"otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_swiglu));

PD_BUILD_OP(te_swiglu_fp8)
.Inputs({"Input", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"Output", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.Attrs({"index: int64_t", "otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_swiglu_fp8));

PD_BUILD_OP(te_dswiglu)
.Inputs({"Grad", "Input"})
.Outputs({"Output"})
.Attrs({"otype: int64_t"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_dswiglu));

PD_BUILD_OP(te_cast_transpose_bgrad_dgelu)
.Inputs({"GradOutput", "GeluInput", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"CastedDgelu", "TransposedDgelu", "Dbias", "Amax", "ScaleInv"})
@@ -1404,20 +1490,21 @@ PD_BUILD_OP(te_layernorm_bwd)
PD_BUILD_OP(te_rmsnorm_fwd)
.Inputs({"Input", "Weight"})
.Outputs({"Output", "InvVariance"})
.Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t"})
.Attrs({"eps: float", "otype: int64_t", "sm_margin: int64_t", "zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd));

PD_BUILD_OP(te_rmsnorm_fwd_fp8)
.Inputs({"Input", "Weight", "Scale", "_Amax", "_ScaleInv"})
.Outputs({"Output", "InvVariance", "Amax", "ScaleInv"})
.SetInplaceMap({{"_Amax", "Amax"}, {"_ScaleInv", "ScaleInv"}})
.Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t"})
.Attrs({"eps: float", "index: int64_t", "otype: int64_t", "sm_margin: int64_t",
"zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_fwd_fp8));

PD_BUILD_OP(te_rmsnorm_bwd)
.Inputs({"Dz", "X", "Rsigma", "Gamma"})
.Outputs({"Dx", "Dgamma"})
.Attrs({"sm_margin: int64_t"})
.Attrs({"sm_margin: int64_t", "zero_centered_gamma: bool"})
.SetKernelFn(PD_KERNEL(transformer_engine::paddle_ext::te_rmsnorm_bwd));

PD_BUILD_OP(te_fused_attn_fwd_qkvpacked)
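Each `PD_BUILD_OP` registration above exposes its kernel as a Paddle custom op on the compiled `transformer_engine_paddle` module, which is exactly what the Python wrappers in cpp_extensions.py call. An illustrative direct call is sketched below, assuming the single-output op returns its tensor directly, as the `swiglu()` wrapper implies.

```python
# Illustration only (not from the commit): reaching the registered te_swiglu op directly,
# the same call the swiglu() wrapper in cpp_extensions.py makes.
import paddle
import transformer_engine_paddle as tex  # compiled extension holding the PD_BUILD_OP kernels

x = paddle.rand(shape=(16, 32), dtype='bfloat16')
y = tex.te_swiglu(x, int(tex.DType.kBFloat16))  # output keeps the rows, halves the columns
print(y.shape)  # expected: [16, 16]
```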
2 changes: 1 addition & 1 deletion transformer_engine/paddle/layer/__init__.py
@@ -3,7 +3,7 @@
# See LICENSE for license information.
"""Layer level Paddle APIs"""

from .attention import DotProductAttention, MultiHeadAttention
from .attention import DotProductAttention, MultiHeadAttention, RotaryPositionEmbedding
from .layernorm import LayerNorm
from .layernorm_linear import LayerNormLinear
from .layernorm_mlp import LayerNormMLP