[Rotary] Add test for rotary when qkv are packed and there's GQA
tridao committed Sep 13, 2024
1 parent 8c20cfe commit cc1690d
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions tests/test_rotary.py
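
For context, a minimal sketch (not part of the commit) of the two packed-QKV layouts the new gqa parameter exercises; the toy shapes and the nheads_k = nheads // 2 choice are assumptions that mirror the test below:

import torch

batch_size, seqlen, nheads, headdim = 2, 16, 4, 32
nheads_k = nheads // 2  # GQA: fewer key/value heads than query heads

# Without GQA, q/k/v share one head count, so they pack along a dedicated dim of size 3.
qkv_mha = torch.randn(batch_size, seqlen, 3, nheads, headdim)
q, k, v = qkv_mha.unbind(dim=2)  # each: (batch, seqlen, nheads, headdim)

# With GQA, q and k/v have different head counts, so everything packs along the head dim
# and is recovered with a split, as the updated test does.
qkv_gqa = torch.randn(batch_size, seqlen, nheads + 2 * nheads_k, headdim)
q, k, v = qkv_gqa.split([nheads, nheads_k, nheads_k], dim=2)
# q: (batch, seqlen, nheads, headdim); k, v: (batch, seqlen, nheads_k, headdim)

Rotary is applied to q and k only, and v passes through unchanged, which is why the reference path in the diff rotates q_pt and k_pt and then stacks (or, for GQA, concatenates) v_pt back.
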
@@ -97,13 +97,15 @@ def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype
     "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
 )
 # @pytest.mark.parametrize('dtype', ([torch.float16]))
+@pytest.mark.parametrize("gqa", [False, True])
+# @pytest.mark.parametrize("gqa", [False])
 @pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
 # @pytest.mark.parametrize("seqlen_offsets_type", [0])
 @pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
 # @pytest.mark.parametrize('rotary_fraction', [1.0])
 @pytest.mark.parametrize("interleaved", [False, True])
 # @pytest.mark.parametrize('interleaved', [False])
-def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
+def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, gqa, dtype):
     rtol = 1e-3
     batch_size = 32
     nheads = 4
@@ -112,23 +114,37 @@ def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype
     device = "cuda"
     rotary_dim = int(rotary_fraction * headdim)
     torch.manual_seed(42)
-    qkv = torch.randn(
-        batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
-    )
+    if not gqa:
+        qkv = torch.randn(
+            batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
+        )
+    else:
+        nheads_k = nheads // 2
+        qkv = torch.randn(
+            batch_size, seqlen, nheads + nheads_k * 2, headdim, dtype=dtype, device=device, requires_grad=True
+        )
     qkv_pt = qkv.detach().clone().requires_grad_()
     cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
     seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
     out = apply_rotary_emb_qkv_(
-        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
+        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved,
+        num_heads_q=None if not gqa else nheads
     )
     cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
+    if not gqa:
+        q_pt, k_pt, v_pt = qkv_pt.unbind(2)
+    else:
+        q_pt, k_pt, v_pt = qkv_pt.split([nheads, nheads_k, nheads_k], dim=2)
     q_pt = apply_rotary_emb_torch(
-        qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
+        q_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
     ).to(dtype=dtype)
     k_pt = apply_rotary_emb_torch(
-        qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
+        k_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
     ).to(dtype=dtype)
-    out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
+    if not gqa:
+        out_pt = torch.stack([q_pt, k_pt, v_pt], dim=2)
+    else:
+        out_pt = torch.cat([q_pt, k_pt, v_pt], dim=2)
     print(f"Output max diff: {(out - out_pt).abs().max().item()}")

     g = torch.randn_like(out)