Commit d7dd45d

metascroy authored and facebook-github-bot committed
Fix CoreML iOS26 numerics in attention (#16144)
Summary: This diff decomposes SDPA to fix iOS26 numerics in Core ML. It also removes repeat_interleave, which further improves performance on Core ML by about 10-15%, depending on the hardware.

Differential Revision: D88705980
1 parent 9156fff commit d7dd45d


examples/models/llama/static_attention.py

Lines changed: 47 additions & 5 deletions
@@ -23,6 +23,8 @@
 _InputCacheState = Tuple[_CacheMap, _CacheMap]
 _OutputCacheState = Tuple[_CacheMap, _CacheMap]
 
+# This fixes numerics on iOS26. Possibly disable in future, depending on bug fixes in Core ML runtime
+_DECOMPOSE_SDPA_IN_STATIC_ATTENTION_MHA: bool = True
 
 def none_throws(x: Optional[Any]) -> Any:
     assert x is not None
@@ -1027,16 +1029,56 @@ def _forward_mha(
         k, out_cache_state = self.k_caches[0].update(k, in_cache_state, out_cache_state)
         v, out_cache_state = self.v_caches[0].update(v, in_cache_state, out_cache_state)
 
-        if self.n_rep > 1:
-            k = k.repeat_interleave(self.n_rep, dim=1)
-            v = v.repeat_interleave(self.n_rep, dim=1)
-
         mask = None
         masks = kwargs.get("masks")
         if masks:
             cache_len = k.size(-2) - seq_len
             mask = masks[cache_len]
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+
+        if not _DECOMPOSE_SDPA_IN_STATIC_ATTENTION_MHA:
+            if self.n_rep > 1:
+                k = k.repeat_interleave(self.n_rep, dim=1)
+                v = v.repeat_interleave(self.n_rep, dim=1)
+            y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask)
+        else:
+            # We remove the bsz dim to keep matmuls on 4D tensors;
+            # Core ML sometimes fails at runtime when given 5D tensors.
+            assert bsz == 1, "Batch size > 1 not supported yet"
+
+            n_kv = self.n_kv_heads
+            n_rep = self.n_rep
+            D = self.head_dim
+
+            # Explicitly track lengths; they are NOT necessarily equal.
+            Tq = q.size(-2)  # query length (current step/window), e.g. 64
+            Tk = k.size(-2)  # key/value length (cache length), e.g. 2048
+
+            # Group Q to match the KV layout
+            # q: (bsz=1, n_heads, Tq, D), with n_heads = n_kv * n_rep
+            # 1 * n_heads * Tq * D == n_kv * n_rep * Tq * D
+            # q_grouped: (n_kv, n_rep, Tq, D)
+            q_grouped = q.view(n_kv, n_rep, Tq, D)
+
+            # Prepare K for the grouped KV matmul
+            # k: (1, n_kv, Tk, D) -> (n_kv, 1, Tk, D)
+            k_grouped = k.view(n_kv, 1, Tk, D)
+
+            # (n_kv, n_rep, Tq, Tk)
+            attn_grouped = q_grouped @ k_grouped.transpose(-2, -1)
+            attn_grouped = attn_grouped * self.inv_scale
+
+            # Ungroup, add the mask, and regroup
+            attn_grouped = attn_grouped.view(1, self.n_heads, Tq, Tk)
+            attn_grouped = attn_grouped + mask
+            attn_grouped = F.softmax(attn_grouped, dim=-1)
+            attn_grouped = attn_grouped.view(n_kv, n_rep, Tq, Tk)
+
+            # Group v
+            v_grouped = v.view(n_kv, 1, Tk, D)
+            y_grouped = attn_grouped @ v_grouped
+
+            # Ungroup y
+            y = y_grouped.view(1, self.n_heads, Tq, D)
 
         return y.transpose(1, 2).contiguous().view(bsz, seq_len, -1), out_cache_state