
Commit

using sdpa if available
jongwook committed Sep 30, 2024
1 parent 423492d commit 8b5b497
Showing 2 changed files with 44 additions and 11 deletions.
51 changes: 41 additions & 10 deletions whisper/model.py
@@ -1,5 +1,6 @@
 import base64
 import gzip
+from contextlib import contextmanager
 from dataclasses import dataclass
 from typing import Dict, Iterable, Optional

@@ -12,6 +13,14 @@
 from .decoding import detect_language as detect_language_function
 from .transcribe import transcribe as transcribe_function
 
+try:
+    from torch.nn.functional import scaled_dot_product_attention
+
+    SDPA_AVAILABLE = True
+except (ImportError, RuntimeError, OSError):
+    scaled_dot_product_attention = None
+    SDPA_AVAILABLE = False
+
 
 @dataclass
 class ModelDimensions:

@@ -59,7 +68,19 @@ def sinusoids(length, channels, max_timescale=10000):
     return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
 
 
+@contextmanager
+def disable_sdpa():
+    prev_state = MultiHeadAttention.use_sdpa
+    try:
+        MultiHeadAttention.use_sdpa = False
+        yield
+    finally:
+        MultiHeadAttention.use_sdpa = prev_state
+
+
 class MultiHeadAttention(nn.Module):
+    use_sdpa = True
+
     def __init__(self, n_state: int, n_head: int):
         super().__init__()
         self.n_head = n_head

@@ -91,21 +112,31 @@ def forward(
         return self.out(wv), qk
 
     def qkv_attention(
-        self, q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None
-    ):
+        self, q: Tensor, k: Tensor, v: Tensor, mask: Tensor | None = None
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
         n_batch, n_ctx, n_state = q.shape
         scale = (n_state // self.n_head) ** -0.25
-        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
-        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
+        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
+        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
         v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
 
-        qk = q @ k
-        if mask is not None:
-            qk = qk + mask[:n_ctx, :n_ctx]
-        qk = qk.float()
+        if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
+            a = scaled_dot_product_attention(
+                q, k, v, is_causal=mask is not None and n_ctx > 1
+            )
+            out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
+            qk = None
+        else:
+            qk = (q * scale) @ (k * scale).transpose(-1, -2)
+            if mask is not None:
+                qk = qk + mask[:n_ctx, :n_ctx]
+            qk = qk.float()
 
-        w = F.softmax(qk, dim=-1).to(q.dtype)
-        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()
+            w = F.softmax(qk, dim=-1).to(q.dtype)
+            out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
+            qk = qk.detach()
+
+        return out, qk
 
 
 class ResidualAttentionBlock(nn.Module):
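
For illustration only (not part of the commit): a minimal sketch of how the new toggle behaves, assuming whisper is importable and the installed PyTorch provides scaled_dot_product_attention. The layer sizes are arbitrary. Note that on the SDPA path the mask tensor itself is not applied; the call only passes is_causal, which matches how the decoder uses the mask as a causal mask.

import torch
from whisper.model import MultiHeadAttention, disable_sdpa

attn = MultiHeadAttention(n_state=64, n_head=4)
x = torch.randn(1, 10, 64)

_, qk = attn(x)       # SDPA path (when available): attention weights are not materialized
print(qk)             # None on the SDPA path

with disable_sdpa():  # temporarily force the original matmul/softmax implementation
    _, qk = attn(x)
print(qk.shape)       # torch.Size([1, 4, 10, 10])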
4 changes: 3 additions & 1 deletion whisper/timing.py
@@ -191,7 +191,9 @@ def find_alignment(
         for i, block in enumerate(model.decoder.blocks)
     ]
 
-    with torch.no_grad():
+    from .model import disable_sdpa
+
+    with torch.no_grad(), disable_sdpa():
         logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
         sampled_logits = logits[len(tokenizer.sot_sequence) :, : tokenizer.eot]
         token_probs = sampled_logits.softmax(dim=-1)
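
Context for the timing.py change (a reading of the diff, not text from the commit): find_alignment reads the qk tensor returned by each decoder cross-attention block through forward hooks, and on the SDPA fast path that tensor is None, so the forward pass is wrapped in disable_sdpa(). The snippet below is a hedged, self-contained sketch of the same effect, using a bare MultiHeadAttention as a stand-in for a cross-attention block; the shapes and names are made up for the example.

import torch
from whisper.model import MultiHeadAttention, disable_sdpa

cross_attn = MultiHeadAttention(n_state=64, n_head=4)
captured = []
hook = cross_attn.register_forward_hook(
    # the module returns (output, qk); keep only the qk part
    lambda _module, _inputs, outputs: captured.append(outputs[-1])
)

x = torch.randn(1, 10, 64)   # stand-in for decoder states
xa = torch.randn(1, 20, 64)  # stand-in for encoder states

cross_attn(x, xa)            # SDPA path: the hook sees qk = None
with disable_sdpa():
    cross_attn(x, xa)        # manual path: the hook sees the attention matrix

hook.remove()
print(captured[0])           # None
print(captured[1].shape)     # torch.Size([1, 4, 10, 20])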

0 comments on commit 8b5b497
