PyTorch MultiheadAttention API (#387)
* PyTorch MultiheadAttention API

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix ONNX export tests

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Expose MultiheadAttention for import

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Expand mask type and add no mask numerical test

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
ksivaman authored Aug 19, 2023
1 parent f29efb7 commit 8aa2da1
Showing 6 changed files with 288 additions and 27 deletions.
docs/api/pytorch.rst (3 changes: 3 additions & 0 deletions)
@@ -22,6 +22,9 @@ pyTorch
.. autoapiclass:: transformer_engine.pytorch.DotProductAttention(num_attention_heads, kv_channels, **kwargs)
  :members: forward

.. autoapiclass:: transformer_engine.pytorch.MultiheadAttention(hidden_size, num_attention_heads, **kwargs)
  :members: forward

.. autoapiclass:: transformer_engine.pytorch.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads, **kwargs)
  :members: forward

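A minimal usage sketch for the newly documented class, assuming only the signature shown above, MultiheadAttention(hidden_size, num_attention_heads, **kwargs), and the sequence-first (seq_len, batch, hidden_size) activation layout used by the numerics tests below; the concrete sizes are illustrative and a CUDA device is required.

import torch
import transformer_engine.pytorch as te

# Hypothetical sizes; only hidden_size and num_attention_heads come from the documented signature.
hidden_size, num_heads = 1024, 16
mha = te.MultiheadAttention(hidden_size, num_heads).cuda()

# Self-attention over hidden states laid out as (seq_len, batch, hidden_size).
x = torch.randn(128, 2, hidden_size, device="cuda")
out = mha(x)
print(out.shape)  # expected: torch.Size([128, 2, 1024])
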
tests/pytorch/test_numerics.py (87 changes: 84 additions & 3 deletions)
@@ -21,7 +21,8 @@
attention_mask_func,
)
from transformer_engine.pytorch import (
-    DotProductAttention, Linear, LayerNormLinear, LayerNormMLP, TransformerLayer, RMSNorm
+    DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
+    MultiheadAttention, RMSNorm, TransformerLayer
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint

@@ -60,6 +61,9 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq

all_normalizations = ["LayerNorm", "RMSNorm"]

mask_types = ["causal", "no_mask"]


def get_causal_attn_mask(sq: int) -> torch.Tensor:
return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool()

@@ -320,6 +324,7 @@ def forward(

return context_layer


# Adapted from https://github.com/bzhangGo/rmsnorm/blob/c6691f20ec0af4128c8159c903071f7575404295/rmsnorm_torch.py
class TorchRMSNorm(nn.Module):
def __init__(self, in_features, eps=1e-5):
@@ -341,6 +346,7 @@ def forward(self, x):

return (self.weight.float() * x_normed).to(x.dtype)


class TorchLayerNormLinear(nn.Module):
def __init__(self, in_features: int, out_features: int,
eps: float, bias: bool = True,
@@ -371,14 +377,19 @@ def __init__(self, hidden_size: int, num_attention_heads: int):
)

def forward(self, x, attn_mask=None):
-        return self.mhsa(x, x, x, attn_mask=attn_mask, need_weights=False)
+        output = self.mhsa(x, x, x, attn_mask=attn_mask, need_weights=False)
+        if isinstance(output, tuple):
+            output = output[0]
+        return output


_supported_act = {'geglu' : nn.GELU(approximate="tanh"),
'gelu' : nn.GELU(approximate="tanh"),
'reglu' : nn.ReLU(),
'relu' : nn.ReLU(),
'swiglu' : nn.SiLU()}


class TorchGLU(nn.Module):
def __init__(self, activation: str):
super().__init__()
@@ -391,6 +402,7 @@ def forward(self, x):
a = self.act(a)
return a * b


class TorchLayerNormMLP(nn.Module):
def __init__(self, hidden_size: int, ffn_hidden_size: int,
eps: float = 1e-5, activation = 'gelu',
@@ -431,7 +443,7 @@ def forward(
attn_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
a = self.ln(x)
-        b, _ = self.causal_attn(a, attn_mask)
+        b = self.causal_attn(a, attn_mask)
x = x + self.resid_attn_dropout(b)
n = self.ln_mlp(x)
x = x + self.resid_mlp_dropout(n)
@@ -754,6 +766,75 @@ def test_gpt_accuracy(dtype, bs, model):
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)


def _test_mha_accuracy(block, bs, dtype, config, mask_type):
    reset_rng_states()

    inp_hidden_states = torch.randn(
        config.seq_len, bs, config.hidden_size, dtype=dtype, requires_grad=True
    ).cuda()
    inp_hidden_states.retain_grad()
    inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None

    out = block(inp_hidden_states, inp_attn_mask)
    loss = out.sum()
    loss.backward()

    torch.cuda.synchronize()
    outputs = [out, inp_hidden_states.grad]
    for p in block.parameters():
        if p.requires_grad:
            outputs.append(p.grad)
    return outputs


@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("mask_type", mask_types)
def test_mha_accuracy(dtype, bs, model, mask_type):
config = model_configs[model]

te_mha = (
MultiheadAttention(
config.hidden_size,
config.num_attention_heads,
fuse_qkv_params=True,
qkv_weight_interleaved=False,
input_layernorm=False,
attn_mask_type=mask_type,
)
.to(dtype=dtype)
.cuda()
.eval()
)

torch_mha = (
TorchMHA(
config.hidden_size,
config.num_attention_heads,
)
.to(dtype=dtype)
.cuda()
.eval()
)

# Share params
with torch.no_grad():
torch_mha.mhsa.in_proj_weight = Parameter(te_mha.qkv.weight.clone())
torch_mha.mhsa.in_proj_bias = Parameter(te_mha.qkv.bias.clone())
torch_mha.mhsa.out_proj.weight = Parameter(te_mha.proj.weight.clone())
torch_mha.mhsa.out_proj.bias = Parameter(te_mha.proj.bias.clone())

te_outputs = _test_mha_accuracy(te_mha, bs, dtype, config, mask_type)
torch_outputs = _test_mha_accuracy(torch_mha, bs, dtype, config, mask_type)

# Check output.
if dtype == torch.float32:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-3)
else:
assert_allclose(te_outputs[0], torch_outputs[0], 5e-2)


def _test_granular_accuracy(block, bs, dtype, config):
reset_rng_states()

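The parameter sharing in test_mha_accuracy above copies te_mha.qkv.weight and te_mha.proj.weight directly into torch.nn.MultiheadAttention's in_proj_weight and out_proj.weight. That only works if, with fuse_qkv_params=True and qkv_weight_interleaved=False, the fused QKV weight is stacked as [Q; K; V] along the output dimension, which is the layout torch.nn.MultiheadAttention expects. The sketch below checks only that inferred shape compatibility, reusing the attribute names (qkv, proj) the test itself relies on; the sizes are illustrative.

import torch
from transformer_engine.pytorch import MultiheadAttention

hidden_size, num_heads = 1024, 16  # illustrative sizes

te_mha = MultiheadAttention(
    hidden_size,
    num_heads,
    fuse_qkv_params=True,
    qkv_weight_interleaved=False,
    input_layernorm=False,
).cuda()
torch_mha = torch.nn.MultiheadAttention(hidden_size, num_heads).cuda()

# The direct weight copy in the test presumes these shapes line up.
assert te_mha.qkv.weight.shape == torch_mha.in_proj_weight.shape    # (3 * hidden, hidden)
assert te_mha.proj.weight.shape == torch_mha.out_proj.weight.shape  # (hidden, hidden)
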
tests/pytorch/test_onnx_export.py (3 changes: 2 additions & 1 deletion)
@@ -1267,14 +1267,15 @@ def test_export_multihead_attention(
input_ln_str = "_input-ln" if input_layernorm else ""
fname = f"te.multihead_attention{fp8_str}{attn_mask_str}{attn_type_str}{input_ln_str}{fuse_qkv_str}{dtype_str}.onnx"

-    model = te.attention.MultiHeadAttention(
+    model = te.MultiheadAttention(
        *attention_args,
        attn_mask_type=attn_mask_type,
        params_dtype=precision,
        return_layernorm_output=return_layernorm_output,
        input_layernorm=input_layernorm,
        attention_type=attention_type,
        fuse_qkv_params=fuse_qkv_params,
+       return_bias=True,
    ).to(device='cuda')

inp_context = (hidden_states_context, attention_mask, encoder_output)
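The export test now passes return_bias=True when constructing te.MultiheadAttention. In other Transformer Engine modules this flag makes the forward pass return the output-projection bias separately instead of adding it, so the caller or a later fused kernel applies it; assuming MultiheadAttention follows the same convention, the result would be unpacked as in this sketch, which is an inference from the flag name rather than something stated in the diff.

import torch
import transformer_engine.pytorch as te

# Assumption: with return_bias=True, forward returns (attention_output, projection_bias).
mha = te.MultiheadAttention(1024, 16, return_bias=True).cuda()
x = torch.randn(128, 2, 1024, device="cuda")
out, bias = mha(x)
y = out + bias  # the bias is applied by the caller, not inside the module
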
transformer_engine/pytorch/__init__.py (1 change: 1 addition & 0 deletions)
@@ -9,6 +9,7 @@
from .module import LayerNorm
from .module import RMSNorm
from .attention import DotProductAttention
from .attention import MultiheadAttention
from .transformer import TransformerLayer
from .fp8 import fp8_autocast
from .export import onnx_export
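The new line in transformer_engine/pytorch/__init__.py re-exports the class at the package root, which is what lets the ONNX test above switch from te.attention.MultiHeadAttention to te.MultiheadAttention. Since the root-level name is just a re-export of the attention submodule's class, both import paths should resolve to the same object:

from transformer_engine.pytorch import MultiheadAttention
from transformer_engine.pytorch.attention import MultiheadAttention as _MHA

assert MultiheadAttention is _MHA  # same class, now also exposed at the package root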
