[PyTorch] Adjusted the logic of MHA and DPA to enable speculative decoding (#668)

* Modified MHA and DPA logic to use causal softmax and FA for inference

Signed-off-by: Oleg Goncharov <[email protected]>

* Adjusted unfused attention and softmax logic for inference

Signed-off-by: Oleg Goncharov <[email protected]>

* Cleaned up the code per pylint

Signed-off-by: Oleg Goncharov <[email protected]>

* Added test cases to evaluate numerics of incremental decoding

Signed-off-by: Oleg Goncharov <[email protected]>

* Apply suggestions from code review

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Oleg Goncharov <[email protected]>

* Apply suggestions from code review [sequence start-end]

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Oleg Goncharov <[email protected]>

* Apply suggestions from code review [inference_params offset update]

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Oleg Goncharov <[email protected]>

* Fixed bug in KV-cache indices and updated test suite

Signed-off-by: Oleg Goncharov <[email protected]>

* Added inference_params description and applied suggestions from the code review

Signed-off-by: Oleg Goncharov <[email protected]>

* Adjusted absolute tolerances in numerics tests

Signed-off-by: Oleg Goncharov <[email protected]>

* Cleaned up the files per pylint

Signed-off-by: Oleg Goncharov <[email protected]>

---------

Signed-off-by: Oleg Goncharov <[email protected]>
Signed-off-by: Oleg Goncharov <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
4 people authored Mar 6, 2024
1 parent 728e335 commit b459ccc
Showing 3 changed files with 243 additions and 85 deletions.
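
For orientation before the file diffs: the new test below drives a TransformerEngine layer token by token against a pre-allocated KV-cache. A minimal sketch of that usage pattern, mirroring the test code in tests/pytorch/test_numerics.py (the model sizes, dtype, and single-layer setup here are illustrative, not part of the commit):

import torch
from transformer_engine.pytorch import MultiheadAttention, InferenceParams

S, B, H, D = 16, 4, 12, 768  # sequence length, batch size, attention heads, hidden size
model = (
    MultiheadAttention(
        hidden_size=D,
        num_attention_heads=H,
        qkv_format="sbhd",
        layer_number=1,
        attention_dropout=0.0,
    )
    .to(dtype=torch.bfloat16)
    .cuda()
    .eval()
)

# Pre-allocate the KV-cache for at most B sequences of at most S tokens.
inference_params = InferenceParams(max_batch_size=B, max_sequence_length=S)

x = torch.randn((S, B, D), dtype=torch.bfloat16, device="cuda")
steps = []
for i in range(S):
    step_input = x[i].view(1, B, D)  # one token per sequence, "sbhd" layout
    out = model(hidden_states=step_input, inference_params=inference_params)
    inference_params.sequence_len_offset += 1  # advance the cache offset after the forward pass
    steps.append(out.view(B, D))
incremental_output = torch.stack(steps)  # numerically matches a single full-sequence forward
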
117 changes: 116 additions & 1 deletion tests/pytorch/test_numerics.py
@@ -22,7 +22,7 @@
)
from transformer_engine.pytorch import (
DotProductAttention, LayerNormLinear, LayerNormMLP, Linear,
- MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm
+ MultiheadAttention, RMSNorm, TransformerLayer, LayerNorm, InferenceParams
)
from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint
from transformer_engine.pytorch.distributed import _set_cuda_rng_state, CudaRNGStatesTracker
@@ -1397,3 +1397,118 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
y_bshd = block_bshd(x_bshd)

assert_all_equal([y_bshd], [y_sbhd.transpose(0,1).contiguous()])


model_configs_inference = {
# hidden_size, eps, num_attention_heads, embed, num_layers, seq_len
"126m": ModelConfig(768, 1e-5, 12, 64, 12, 16),
}
backends_inference = ["FlashAttention", "UnfusedAttention"]
module_inference = ["TransformerLayer", "MultiheadAttention"]
input_formats_inference = ["sbhd", "bshd"]

@pytest.mark.parametrize("dtype", param_types)
@pytest.mark.parametrize("bs", batch_sizes)
@pytest.mark.parametrize("model_key", model_configs_inference.keys())
@pytest.mark.parametrize("use_RoPE", all_boolean)
@pytest.mark.parametrize("input_format", input_formats_inference)
@pytest.mark.parametrize("module", module_inference)
@pytest.mark.parametrize("backend", backends_inference)
def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend):
os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"

if backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
elif backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"

config = model_configs_inference[model_key]

S = config.seq_len
B = bs
H = config.num_attention_heads
D = config.hidden_size
head_size = config.embed
layer_number = 1

# Limits the max size of KV-cache
B_max = B
S_max = S + 2

if module == "TransformerLayer":
model = (
TransformerLayer(
hidden_size=D,
ffn_hidden_size= 4 * D,
num_attention_heads=H,
attn_input_format=input_format,
layer_number=layer_number,
attention_dropout = 0.0
)
.to(dtype=dtype)
.cuda()
.eval()
)
else:
model = (
MultiheadAttention(
hidden_size=D,
num_attention_heads=H,
qkv_format=input_format,
layer_number=layer_number,
attention_dropout = 0.0
)
.to(dtype=dtype)
.cuda()
.eval()
)

inference_params = InferenceParams(max_batch_size=B_max, max_sequence_length=S_max)
rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda")

input = torch.randn((S, B, D), dtype=dtype, device="cuda")
if input_format == "bshd":
input = input.transpose(0, 1).contiguous()

incremental_output = torch.zeros_like(input)

# Generate output for the entire sequence
full_output = model(
hidden_states=input,
rotary_pos_emb=rotary_freqs if use_RoPE else None)

# Incrementally generate outputs using the KV-cache
for i in range(S):
if input_format == "sbhd":
incremental_input = input[i].view(1,B,D)
else:
incremental_input = input[:, i, :].view(B,1,D)

line_output = model(
hidden_states=incremental_input,
inference_params=inference_params,
rotary_pos_emb=rotary_freqs if use_RoPE else None)

inference_params.sequence_len_offset += 1

if input_format == "sbhd":
incremental_output[i] = line_output.view(B,D)
else:
incremental_output[:, i, :] = line_output.view(B,D)

if module == "TransformerLayer":
atol = {
torch.float32 : 5e-3,
torch.half : 5e-3,
torch.bfloat16: 5e-2,
}
else:
atol = {
torch.float32 : 1e-3,
torch.half : 1e-3,
torch.bfloat16: 1e-2,
}

# Check if the fully generated output matches the one generated incrementally
assert_allclose(full_output, incremental_output, atol[dtype])
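
The attention.py changes below copy the current step's keys and values into the pre-allocated cache and then attend over the cache prefix. A standalone plain-PyTorch sketch of that indexing, separate from the TE code and with illustrative tensor sizes:

import torch

S_max, B_max, h, d = 18, 4, 12, 64  # cache capacity and head geometry (illustrative)
key_cache = torch.zeros(S_max, B_max, h, d)
value_cache = torch.zeros(S_max, B_max, h, d)

sequence_len_offset = 5                  # tokens already cached from earlier decode steps
key_step = torch.randn(1, B_max, h, d)   # K for the single new token, "sbhd" layout
value_step = torch.randn(1, B_max, h, d)

sequence_start = sequence_len_offset
sequence_end = sequence_start + key_step.size(0)

# Write the new token's K/V at its absolute position in the cache.
key_cache[sequence_start:sequence_end] = key_step
value_cache[sequence_start:sequence_end] = value_step

# Attention for this step then sees every cached token plus the new one.
key_layer = key_cache[:sequence_end]     # shape [6, B_max, h, d]
value_layer = value_cache[:sequence_end]
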
135 changes: 86 additions & 49 deletions transformer_engine/pytorch/attention.py
Expand Up @@ -84,7 +84,6 @@

__all__ = ["DotProductAttention", "InferenceParams", "MultiheadAttention"]


class InferenceParams: # pylint: disable=too-few-public-methods
"""
Inference parameters that are passed to the main model in order
@@ -1180,7 +1179,7 @@ def apply_rotary_pos_emb(
Parameters
----------
t: torch.Tensor
- Input tensor of shape `[s, b, h, d]`, `[s, b, h, d]` or `[t, h, d]`, on which
+ Input tensor of shape `[s, b, h, d]`, `[b, s, h, d]` or `[t, h, d]`, on which
rotary positional embedding will be applied.
freqs: torch.Tensor
Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float',
@@ -2523,6 +2522,7 @@ def forward(
core_attention_bias: Optional[torch.Tensor] = None,
alibi_slopes: Optional[torch.Tensor] = None,
fast_zero_fill: bool = True,
inference_params: Optional[InferenceParams] = None,
) -> torch.Tensor:
"""
Dot Product Attention Layer.
@@ -2616,6 +2616,16 @@ def forward(
to the attention score of query i and key j.
fast_zero_fill: bool, default = `True`
Whether to use the fast path to set output tensors to 0 or not.
inference_params: Optional[InferenceParams], default = `None`
Optimizes execution performance during inference by caching Keys and Values of the
current decoding iteration. These cached values are appended to the K and V values
computed in previous iterations, eliminating the need to recalculate them for the
entire sequence.
Initialization of `inference_params` is required prior to use to ensure sufficient
memory allocation.
Adjustments of the sequence_len_offset should be done after a complete forward pass.
If rotary positional embeddings (RoPE) are utilized, they must be prepared beforehand.
Supports "sbhd" and "bshd" layouts, with the "sbhd" layout being more efficient.
"""

assert (
@@ -2643,6 +2653,39 @@ def forward(
if qkv_format is None:
qkv_format = self.qkv_format

if inference_params is not None:
assert self.layer_number is not None, "Layer number must be set!"

if qkv_format == "bshd":
key_layer = key_layer.transpose(0, 1)
value_layer = value_layer.transpose(0, 1)

(inference_key_memory, inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]

batch_start = inference_params.batch_size_offset
batch_end = batch_start + key_layer.size(1)
assert batch_end <= inference_key_memory.size(1)

sequence_start = inference_params.sequence_len_offset
sequence_end = sequence_start + key_layer.size(0)
assert sequence_end <= inference_key_memory.size(0)

# Copy keys and values into KV-cache
inference_key_memory[
sequence_start:sequence_end, batch_start:batch_end, ...] = key_layer
inference_value_memory[
sequence_start:sequence_end, batch_start:batch_end, ...] = value_layer
key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
value_layer = inference_value_memory[:sequence_end, batch_start:batch_end, ...]

if qkv_format == "bshd":
key_layer = key_layer.transpose(0, 1)
value_layer = value_layer.transpose(0, 1)

key_layer = key_layer.contiguous()
value_layer = value_layer.contiguous()

assert (key_layer.shape[-2] == self.num_gqa_groups_per_partition
and value_layer.shape[-2] == self.num_gqa_groups_per_partition
), f"Keys and values must have num_gqa_group = {self.num_gqa_groups} heads!"
@@ -2721,12 +2764,15 @@ def forward(
use_flash_attention = False

# Filter: cross attention + causal mask.
- if (_flash_attn_2_1_plus
+ # (in training mode)
+ if (inference_params is None
+ and _flash_attn_2_1_plus
and "causal" in attn_mask_type
- and max_seqlen_q != max_seqlen_kv):
+ and max_seqlen_q != max_seqlen_kv
+ ):
warnings.warn(
- "Disabling the use of FlashAttention since version 2.1+ has changed its behavior "
- "for causal mask in cross attention. See "
+ "In training mode, disable the use of FlashAttention since version 2.1+ has "
+ "changed its behavior for causal mask in cross attention. See "
"https://github.com/Dao-AILab/flash-attention#21-change-behavior-of-causal-flag"
)
use_flash_attention = False
@@ -2753,7 +2799,11 @@ def forward(
if attn_mask_type == "arbitrary":
use_flash_attention = False
use_fused_attention = False
if "causal" in attn_mask_type and max_seqlen_q != max_seqlen_kv:

if (inference_params is None
and "causal" in attn_mask_type
and max_seqlen_q != max_seqlen_kv
):
use_unfused_attention = False

# Filter: bias.
@@ -3446,12 +3496,12 @@ def forward(
), f"core_attention_bias_type {core_attention_bias_type} is not supported!"

# =================================================
- # Pre-allocate memory for key-values for inference.
+ # Pre-allocate memory for key-values for inference
# =================================================

if inference_params and self.layer_number is not None:
if self.layer_number not in inference_params.key_value_memory_dict:
- inf_max_seq_len = inference_params.max_sequence_len
+ inf_max_seq_len = inference_params.max_sequence_length
inf_max_batch_size = inference_params.max_batch_size
inference_key_memory = self._allocate_memory(
inf_max_seq_len, inf_max_batch_size, hidden_states.dtype
@@ -3469,9 +3519,9 @@ def forward(
inference_value_memory,
) = inference_params.key_value_memory_dict[self.layer_number]

- # =====================
+ # ======================
# Query, Key, and Value
- # =====================
+ # ======================

if self.attention_type == "self":
# Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn]
@@ -3593,51 +3643,37 @@ def forward(
)
query_layer = query_layer.view(*new_tensor_shape)

- # ==================================
- # Adjust key and value for inference
- # ==================================
+ # ======================================================
+ # Apply relative positional encoding (rotary embedding)
+ # ======================================================

- # duplicate the pos_emb for self attention
if rotary_pos_emb is not None:
+ # duplicate the pos_emb for self attention
if not isinstance(rotary_pos_emb, tuple):
rotary_pos_emb = ((rotary_pos_emb,) * 2)

- if inference_params and self.layer_number is not None:
- batch_start = inference_params.batch_size_offset
- batch_end = batch_start + key_layer.size(1)
- assert batch_end <= inference_key_memory.size(1)
- sequence_start = inference_params.sequence_len_offset
- sequence_end = sequence_start + key_layer.size(0)
- assert sequence_end <= inference_key_memory.size(0)
- # Copy key and values.
- inference_key_memory[
- sequence_start:sequence_end, batch_start:batch_end, ...
- ] = key_layer
- inference_value_memory[
- sequence_start:sequence_end, batch_start:batch_end, ...
- ] = value_layer
- key_layer = inference_key_memory[:sequence_end, batch_start:batch_end, ...]
- value_layer = inference_value_memory[
- :sequence_end, batch_start:batch_end, ...
- ]

- # adjust the key rotary positional embedding
- if rotary_pos_emb is not None:
- q_pos_emb, k_pos_emb = rotary_pos_emb
- q_pos_emb = q_pos_emb[sequence_start:sequence_end, :, :, :]
- k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
- rotary_pos_emb = (q_pos_emb, k_pos_emb)

- # ==================================
- # core attention computation
- # ==================================

- # apply relative positional encoding (rotary embedding)
- if rotary_pos_emb is not None:
q_pos_emb, k_pos_emb = rotary_pos_emb

+ # adjust key and value for inference
+ if inference_params is not None:
+ if self.qkv_format == "sbhd":
+ sequence_length = key_layer.size(0)
+ elif self.qkv_format == "bshd":
+ sequence_length = key_layer.size(1)

+ sequence_start = inference_params.sequence_len_offset
+ sequence_end = sequence_start + sequence_length

+ q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...]
+ k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...]

query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.qkv_format, fused=True)
key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.qkv_format, fused=True)

+ # ===========================
+ # Core attention computation
+ # ===========================

context_layer = self.core_attention(
query_layer,
key_layer,
@@ -3653,11 +3689,12 @@ def forward(
core_attention_bias=core_attention_bias,
alibi_slopes=alibi_slopes,
fast_zero_fill=fast_zero_fill,
inference_params=inference_params,
)

- # =================
+ # ===================
# Output. [sq, b, h]
- # =================
+ # ===================

projection_output = self.proj(
context_layer, is_first_microbatch=is_first_microbatch
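The rotary-embedding handling in the MultiheadAttention diff above reduces to indexing the precomputed frequency table at the new token's absolute position; the cached keys already carry their rotary phase from earlier steps, so only the current step's query and key need it. A plain-PyTorch sketch of that slicing (the table shape follows the rotary_freqs tensor in the test; the offset values are illustrative):

import torch

S_max, head_dim = 18, 64
rotary_freqs = torch.randn(S_max, 1, 1, head_dim)  # [s, 1, 1, d] frequency table

sequence_len_offset = 5  # tokens generated so far
step_len = 1             # one new token this step

sequence_start = sequence_len_offset
sequence_end = sequence_start + step_len

# Frequencies applied to the new token's query and key at this decode step.
q_freqs = rotary_freqs[sequence_start:sequence_end]  # shape [1, 1, 1, head_dim]
k_freqs = rotary_freqs[sequence_start:sequence_end]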
