[PyTorch] Activation operations (#1164)
* Add activation ops

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix lint warnings

Signed-off-by: Tim Moon <[email protected]>

* Fix linter warning

Signed-off-by: Tim Moon <[email protected]>

* Update to use QuantizedTensor

Signed-off-by: Tim Moon <[email protected]>

* Respect PyTorch autograd dtype

Signed-off-by: Tim Moon <[email protected]>

* Rename CastFloat8 op to Quantize

Signed-off-by: Tim Moon <[email protected]>

* Add support for fused dSwiGLU-cast-transpose

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
timmoon10 and pre-commit-ci[bot] authored Nov 15, 2024
1 parent d1488e7 commit 20b0473
Showing 7 changed files with 671 additions and 0 deletions.
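Before the per-file diffs, here is a minimal usage sketch of the new fusible activation ops combined with the renamed Quantize op, based on the tests added in this commit. The te / te_ops import aliases mirror the test file and are assumptions rather than a documented entry point; an FP8-capable GPU is assumed.

# Usage sketch only (assumed import aliases, mirroring tests/pytorch/test_fusible_ops.py).
import torch
import transformer_engine.pytorch as te
import transformer_engine.pytorch.ops as te_ops

# GLU variants expect a doubled inner dimension: SwiGLU maps 32 -> 16 here.
x = torch.randn(16, 32, device="cuda", requires_grad=True)

forward = te_ops.Sequential(
    te_ops.SwiGLU(),
    te_ops.Quantize(forward=True, backward=False),  # cast the forward output to FP8
)
with te.fp8_autocast(enabled=True):
    y = forward(x)
y.backward(torch.randn(16, 16, device="cuda"))

The same pattern applies to GELU, ReLU, GEGLU, and ReGLU; only the GLU variants expect the doubled inner dimension.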
160 changes: 160 additions & 0 deletions tests/pytorch/test_fusible_ops.py
@@ -1362,6 +1362,166 @@ def test_make_extra_output(
        torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0)
        torch.testing.assert_close(dx_test, x_ref.grad, **tols)

@pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu"))
@pytest.mark.parametrize("out_shape", ((37,), (2, 13), (4, 1, 16)))
@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8_input", (False, True))
@pytest.mark.parametrize("fp8_output", (False, True))
def test_activation(
self,
*,
activation: str,
out_shape: Iterable[int],
dtype: torch.dtype,
device: torch.device = "cuda",
fp8_input: bool,
fp8_output: bool,
) -> None:
"""Activation functions"""

# Tensor dimensions
in_shape = list(out_shape)
if activation in ("geglu", "reglu", "swiglu"):
in_shape[-1] *= 2

# Skip invalid configurations
if fp8_input or fp8_output:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
test_is_fp8=fp8_input,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)

# Plain PyTorch implementation
y_ref: torch.Tensor
if activation == "gelu":
y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh")
elif activation == "relu":
y_ref = torch.nn.functional.relu(x_ref)
elif activation == "geglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2
elif activation == "reglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.relu(x1) * x2
elif activation == "swiglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
else:
raise ValueError(f"Unexpected activation function ({activation})")
y_ref.backward(dy_ref)

# Implementation with fusible operation
make_op = dict(
gelu=te_ops.GELU,
relu=te_ops.ReLU,
geglu=te_ops.GEGLU,
reglu=te_ops.ReGLU,
swiglu=te_ops.SwiGLU,
)[activation]
forward = te_ops.Sequential(
make_op(),
te_ops.Quantize(forward=fp8_output, backward=False),
)
with te.fp8_autocast(enabled=fp8_output):
y_test = forward(x_test)
y_test.backward(dy_test)

# Expected numerical error
tols = dtype_tols(dtype)
if fp8_output:
tols = dtype_tols(tex.DType.kFloat8E4M3)

# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)

@pytest.mark.parametrize("dtype", _dtypes)
@pytest.mark.parametrize("fp8_output", (False, True))
@pytest.mark.parametrize("fp8_grad_input", (False, True))
def test_swiglu(
self,
*,
out_shape: Iterable[int] = (16, 16),
dtype: torch.dtype,
device: torch.device = "cuda",
fp8_output: bool,
fp8_grad_input: bool,
):

# Tensor dimensions
in_shape = list(out_shape)
in_shape[-1] *= 2

# Skip invalid configurations
fp8 = fp8_output or fp8_grad_input
if fp8:
if not fp8_available:
pytest.skip(reason_for_no_fp8)
if torch.device(device).type != "cuda":
pytest.skip("FP8 is only supported on CUDA devices")

# FP8 recipe
fp8_recipe = None
if fp8_grad_input:
fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)

# Random data
x_ref, x_test = make_reference_and_test_tensors(
in_shape,
test_dtype=dtype,
test_device=device,
)
dy_ref, dy_test = make_reference_and_test_tensors(
out_shape,
test_dtype=dtype,
test_device=device,
requires_grad=False,
)

# Plain PyTorch implementation
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.silu(x1) * x2
y_ref.backward(dy_ref)

# Implementation with fusible operation
forward = te_ops.Sequential(
te_ops.Quantize(forward=False, backward=fp8_grad_input),
te_ops.SwiGLU(),
te_ops.Quantize(forward=fp8_output, backward=False),
)
with te.fp8_autocast(enabled=fp8, fp8_recipe=fp8_recipe):
y_test = forward(x_test)
y_test.backward(dy_test)

# Expected numerical error
tols = dtype_tols(dtype)
if fp8:
tols = dtype_tols(tex.DType.kFloat8E4M3)

# Check results
y_test = y_test.to(dtype=torch.float64, device="cpu")
dx_test = x_test.grad.to(dtype=torch.float64, device="cpu")
torch.testing.assert_close(y_test, y_ref, **tols)
torch.testing.assert_close(dx_test, x_ref.grad, **tols)


class TestFusedOps:
    """Tests for fused operations"""
39 changes: 39 additions & 0 deletions transformer_engine/pytorch/cpp_extensions/transpose.py
@@ -16,6 +16,7 @@
    "fp8_cast_transpose_fused",
    "fp8_cast_transpose_bgrad_fused",
    "fp8_cast_transpose_bgrad_dgelu_fused",
    "fp8_dswiglu_cast_transpose_fused",
    "fp8_multi_cast_transpose_fused",
    "fp8_transpose_bgrad_fused",
]
@@ -168,6 +169,44 @@ def fp8_cast_transpose_bgrad_dgelu_fused(
    )


def fp8_dswiglu_cast_transpose_fused(
    grad_output: torch.Tensor,
    inp: torch.Tensor,
    *,
    grad_input: torch.Tensor,
    grad_input_transpose: torch.Tensor,
    otype: tex.DType,
    fp8_meta: Optional[tex.FP8TensorMeta] = None,
    fp8_meta_index: Union[tex.FP8FwdTensors, tex.FP8BwdTensors, None] = None,
    scale: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
) -> None:
    """Fused SwiGLU backward + FP8 cast + FP8 transpose"""

    # Get FP8 scaling factors
    fp8_scales, fp8_scales_offsets = canonicalize_fp8_scales(
        scale=scale,
        amax=amax,
        scale_inv=scale_inv,
        fp8_meta=fp8_meta,
        fp8_meta_index=fp8_meta_index,
    )

    # Launch kernel
    return tex.fused_dswiglu_cast_transpose(
        grad_output,
        inp,
        grad_input,
        grad_input_transpose,
        fp8_scales["scale"],
        fp8_scales["amax"],
        fp8_scales["scale_inv"],
        otype,
        **fp8_scales_offsets,
    )


def fp8_multi_cast_transpose_fused(
    input_list: List[torch.Tensor],
    fp8_meta_tensor: tex.FP8TensorMeta,
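The wrapper above writes into preallocated FP8 output buffers. A minimal calling sketch follows, assuming the wrapper is re-exported from transformer_engine.pytorch.cpp_extensions and that the C++ extension module is importable as transformer_engine_torch (both assumptions), and using explicit scale/amax/scale_inv tensors instead of an fp8_meta struct:

import torch
import transformer_engine.pytorch.cpp_extensions as ext  # assumed re-export path
import transformer_engine_torch as tex                    # assumed extension module name

M, N = 32, 64  # grad_output is (M, N); the SwiGLU input is (M, 2N)
dy = torch.randn(M, N, device="cuda", dtype=torch.bfloat16)
x = torch.randn(M, 2 * N, device="cuda", dtype=torch.bfloat16)

# Outputs are preallocated uint8 buffers, per the checks in transpose.cpp below.
dx = torch.empty(M, 2 * N, device="cuda", dtype=torch.uint8)
dx_t = torch.empty(2 * N, M, device="cuda", dtype=torch.uint8)

scale = torch.ones(1, device="cuda", dtype=torch.float32)
amax = torch.zeros(1, device="cuda", dtype=torch.float32)
scale_inv = torch.ones(1, device="cuda", dtype=torch.float32)

ext.fp8_dswiglu_cast_transpose_fused(
    dy,
    x,
    grad_input=dx,
    grad_input_transpose=dx_t,
    otype=tex.DType.kFloat8E4M3,
    scale=scale,
    amax=amax,
    scale_inv=scale_inv,
)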
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
Expand Up @@ -210,6 +210,12 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
int scale_offset = 0, int amax_offset = 0,
int scale_inv_offset = 0);

void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input,
at::Tensor grad_input_transpose, at::Tensor scale,
at::Tensor amax, at::Tensor scale_inv,
transformer_engine::DType otype, int scale_offset = 0,
int amax_offset = 0, int scale_inv_offset = 0);

void fused_multi_cast_transpose(std::vector<at::Tensor> input_list,
std::vector<at::Tensor> scale_list,
std::vector<at::Tensor> cast_output_list,
6 changes: 6 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
Expand Up @@ -91,6 +91,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("grad_output"), py::arg("gelu_input"), py::arg("scale"), py::arg("amax"),
py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_dswiglu_cast_transpose", &fused_dswiglu_cast_transpose,
"Fused SwiGLU backward + FP8 cast + FP8 transpose",
py::call_guard<py::gil_scoped_release>(), py::arg("grad_output"), py::arg("input"),
py::arg("grad_input"), py::arg("grad_input_transpose"), py::arg("scale"), py::arg("amax"),
py::arg("scale_inv"), py::arg("otype"), py::arg("scale_offset") = 0,
py::arg("amax_offset") = 0, py::arg("scale_inv_offset") = 0);
m.def("fused_multi_cast_transpose", &fused_multi_cast_transpose,
"Fused Multi-tensor Cast + Transpose", py::call_guard<py::gil_scoped_release>());
m.def("fused_multi_cast_transpose_alloc", &fused_multi_cast_transpose_alloc,
69 changes: 69 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/transpose.cpp
Expand Up @@ -196,6 +196,75 @@ std::vector<at::Tensor> fused_cast_transpose_bgrad_dgelu(at::Tensor grad_output,
return {grad_bias, dgelu, dgelu_transpose};
}

void fused_dswiglu_cast_transpose(at::Tensor grad_output, at::Tensor input, at::Tensor grad_input,
                                  at::Tensor grad_input_transpose, at::Tensor scale,
                                  at::Tensor amax, at::Tensor scale_inv,
                                  transformer_engine::DType otype, int scale_offset,
                                  int amax_offset, int scale_inv_offset) {
  using namespace transformer_engine;

  // Tensor dimensions
  auto outer_dim = [](const at::Tensor& tensor) -> size_t {
    return tensor.numel() / tensor.size(-1);
  };
  const auto M = outer_dim(grad_output);
  const auto N = static_cast<size_t>(grad_output.size(-1));

  // Check tensor dims
  NVTE_CHECK(grad_output.dim() == 2, "Expected grad output tensor to have 2 dims, but found ",
             grad_output.dim());
  NVTE_CHECK(input.dim() == 2, "Expected input tensor to have 2 dims, but found ", input.dim());
  NVTE_CHECK(outer_dim(input) == M, "Expected input tensor to have outer dimension of ", M,
             ", but found ", outer_dim(input));
  NVTE_CHECK(input.size(-1) == 2 * N, "Expected input tensor to have inner dimension of ", 2 * N,
             ", but found ", input.size(-1));
  NVTE_CHECK(grad_input.dim() == 2, "Expected grad input tensor to have 2 dims, but found ",
             grad_input.dim());
  NVTE_CHECK(outer_dim(grad_input) == M, "Expected grad input tensor to have outer dimension of ",
             M, ", but found ", outer_dim(grad_input));
  NVTE_CHECK(grad_input.size(-1) == 2 * N, "Expected grad input tensor to have inner dimension of ",
             2 * N, ", but found ", grad_input.size(-1));
  NVTE_CHECK(grad_input_transpose.dim() == 2,
             "Expected grad input transpose tensor to have 2 dims, but found ",
             grad_input_transpose.dim());
  NVTE_CHECK(grad_input_transpose.size(0) == 2 * N,
             "Expected grad input transpose tensor to have outer dimension of ", 2 * N,
             ", but found ", grad_input_transpose.size(0));
  NVTE_CHECK(grad_input_transpose.size(1) == M,
             "Expected grad input transpose tensor to have inner dimension of ", M,
             ", but found ", grad_input_transpose.size(1));

  // Check tensor format
  NVTE_CHECK(grad_output.is_contiguous(), "Expected grad output tensor to be contiguous");
  NVTE_CHECK(input.is_contiguous(), "Expected input tensor to be contiguous");
  NVTE_CHECK(grad_input.is_contiguous(), "Expected grad input tensor to be contiguous");
  NVTE_CHECK(grad_input_transpose.is_contiguous(),
             "Expected grad input transpose tensor to be contiguous");
  NVTE_CHECK(grad_output.scalar_type() == input.scalar_type(),
             "Expected grad output tensor and input tensor to have same dtype");
  NVTE_CHECK(grad_input.scalar_type() == at::ScalarType::Byte,
             "Expected grad input tensor to be uint8 buffer");
  NVTE_CHECK(grad_input_transpose.scalar_type() == at::ScalarType::Byte,
             "Expected grad input transpose tensor to be uint8 buffer");

  // Get pointers for FP8 scale, amax, scale-inverse
  void* scale_dptr = getDataPtr(scale, scale_offset);
  void* amax_dptr = getDataPtr(amax, amax_offset);
  void* scale_inv_dptr = getDataPtr(scale_inv, scale_inv_offset);

  // Construct Transformer Engine tensors
  auto dy_cu = makeTransformerEngineTensor(grad_output);
  auto x_cu = makeTransformerEngineTensor(input);
  auto dx_cu = makeTransformerEngineTensor(grad_input.data_ptr(), {M, 2 * N}, otype, amax_dptr,
                                           scale_dptr, scale_inv_dptr);
  auto dx_t_cu = makeTransformerEngineTensor(grad_input_transpose.data_ptr(), {2 * N, M}, otype,
                                             amax_dptr, scale_dptr, scale_inv_dptr);

  // Launch kernel
  nvte_dswiglu_cast_transpose(dy_cu.data(), x_cu.data(), dx_cu.data(), dx_t_cu.data(),
                              at::cuda::getCurrentCUDAStream());
}

void fused_multi_cast_transpose_base(std::vector<at::Tensor> input_list,
                                     std::vector<void*> scale_dptr_list,
                                     std::vector<at::Tensor> cast_output_list,
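For intuition, here is a plain-PyTorch restatement of the dSwiGLU math that this function fuses with the FP8 cast and transpose, using the same chunk ordering as the reference implementation in the new test; it is a sketch of the math, not the kernel.

import torch

def dswiglu_reference(dy: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Forward: y = silu(x1) * x2 with x1, x2 = x.chunk(2, dim=-1)
    x1, x2 = x.chunk(2, dim=-1)
    sig = torch.sigmoid(x1)
    silu = x1 * sig
    dsilu = sig * (1 + x1 * (1 - sig))    # derivative of silu(x1)
    dx1 = dy * x2 * dsilu                 # gradient through the gate branch
    dx2 = dy * silu                       # gradient through the value branch
    return torch.cat([dx1, dx2], dim=-1)  # shape (M, 2N), matching grad_input above

The fused kernel additionally casts this result to FP8 and writes the (2N, M) transpose in the same pass.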
1 change: 1 addition & 0 deletions transformer_engine/pytorch/ops/basic/__init__.py
@@ -4,6 +4,7 @@

"""Single tensor operations supported by the operation fuser."""

from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU
from .add_in_place import AddInPlace
from .all_gather import AllGather
from .all_reduce import AllReduce
(Diff for the remaining changed file, the new activation op module, is not expanded here.)
