|
23 | 23 | _InputCacheState = Tuple[_CacheMap, _CacheMap] |
24 | 24 | _OutputCacheState = Tuple[_CacheMap, _CacheMap] |
25 | 25 |
|
| 26 | +# This fixes numerics on iOS 26. It may be possible to disable this in the future, depending on bug fixes in the Core ML runtime.
| 27 | +_DECOMPOSE_SDPA_IN_STATIC_ATTENTION_MHA: bool = True |
26 | 28 |
|
27 | 29 | def none_throws(x: Optional[Any]) -> Any: |
28 | 30 | assert x is not None |
@@ -1027,16 +1029,56 @@ def _forward_mha( |
1027 | 1029 | k, out_cache_state = self.k_caches[0].update(k, in_cache_state, out_cache_state) |
1028 | 1030 | v, out_cache_state = self.v_caches[0].update(v, in_cache_state, out_cache_state) |
1029 | 1031 |
|
1030 | | - if self.n_rep > 1: |
1031 | | - k = k.repeat_interleave(self.n_rep, dim=1) |
1032 | | - v = v.repeat_interleave(self.n_rep, dim=1) |
1033 | | - |
1034 | 1032 | mask = None |
1035 | 1033 | masks = kwargs.get("masks") |
1036 | 1034 | if masks: |
1037 | 1035 | cache_len = k.size(-2) - seq_len |
1038 | 1036 | mask = masks[cache_len] |
1039 | | - y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) |
| 1037 | + |
| 1038 | + if not _DECOMPOSE_SDPA_IN_STATIC_ATTENTION_MHA: |
| 1039 | + if self.n_rep > 1: |
| 1040 | + k = k.repeat_interleave(self.n_rep, dim=1) |
| 1041 | + v = v.repeat_interleave(self.n_rep, dim=1) |
| 1042 | + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask) |
| 1043 | + else: |
| 1044 | + # We drop the bsz dim to keep matmuls on 4D tensors;
| 1045 | + # Core ML sometimes fails at runtime when given 5D tensors.
| 1046 | + assert bsz == 1, "Batch size > 1 not supported yet" |
| 1047 | + |
| 1048 | + n_kv = self.n_kv_heads |
| 1049 | + n_rep = self.n_rep |
| 1050 | + D = self.head_dim |
| 1051 | + |
| 1052 | + # Explicitly track the query and KV lengths; they are NOT necessarily equal.
| 1053 | + Tq = q.size(-2) # query length (current step/window), e.g. 64 |
| 1054 | + Tk = k.size(-2) # key/value length (cache length), e.g. 2048 |
| 1055 | + |
| 1056 | + # Group Q to match KV layout |
| 1057 | + # q: (bsz=1, n_heads, Tq, D), with n_heads = n_kv * n_rep |
| 1058 | + # 1 * n_heads * Tq * D == n_kv * n_rep * Tq * D |
| 1059 | + # q_grouped: (n_kv, n_rep, Tq, D) |
| 1060 | + q_grouped = q.view(n_kv, n_rep, Tq, D) |
| 1061 | + |
| 1062 | + # Prepare K for grouped KV matmul |
| 1063 | + # k: (1, n_kv, Tk, D) -> (n_kv, 1, Tk, D)
| 1064 | + k_grouped = k.view(n_kv, 1, Tk, D) |
| 1065 | + |
| 1066 | + # (n_kv, n_rep, Tq, Tk) |
| 1067 | + attn_grouped = q_grouped @ k_grouped.transpose(-2, -1) |
| 1068 | + attn_grouped = attn_grouped * self.inv_scale |
| 1069 | + |
| 1070 | + # Ungroup to full heads, add the mask and apply softmax, then regroup
| 1071 | + attn_grouped = attn_grouped.view(1, self.n_heads, Tq, Tk) |
| 1072 | + attn_grouped = attn_grouped + mask |
| 1073 | + attn_grouped = F.softmax(attn_grouped, dim=-1) |
| 1074 | + attn_grouped = attn_grouped.view(n_kv, n_rep, Tq, Tk) |
| 1075 | + |
| 1076 | + # Group v |
| 1077 | + v_grouped = v.view(n_kv, 1, Tk, D) |
| 1078 | + y_grouped = attn_grouped @ v_grouped |
| 1079 | + |
| 1080 | + # Ungroup y |
| 1081 | + y = y_grouped.view(1, self.n_heads, Tq, D) |
1040 | 1082 |
|
1041 | 1083 | return y.transpose(1, 2).contiguous().view(bsz, seq_len, -1), out_cache_state |
1042 | 1084 |
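For context, the decomposed path above computes the same attention as `F.scaled_dot_product_attention`, just expressed as grouped 4D matmuls so Core ML never sees a 5D tensor. Below is a minimal standalone sketch of that equivalence check; the tensor sizes, the synthetic mask, and the `1/sqrt(head_dim)` scale (SDPA's default, assumed here to match `self.inv_scale`) are illustrative assumptions, not taken from this diff.

```python
# Minimal sketch (not part of the diff): verify that the grouped 4D-matmul
# decomposition used in _forward_mha matches SDPA with repeated KV heads.
import torch
import torch.nn.functional as F

n_kv, n_rep, Tq, Tk, D = 2, 4, 8, 32, 16  # illustrative sizes
n_heads = n_kv * n_rep

q = torch.randn(1, n_heads, Tq, D)
k = torch.randn(1, n_kv, Tk, D)
v = torch.randn(1, n_kv, Tk, D)

# Causal-style mask over the last Tq key positions; earlier cache entries stay visible.
mask = torch.zeros(Tq, Tk)
mask[:, Tk - Tq:] = torch.triu(torch.full((Tq, Tq), float("-inf")), diagonal=1)

# Reference path: repeat KV heads and call SDPA directly.
ref = F.scaled_dot_product_attention(
    q,
    k.repeat_interleave(n_rep, dim=1),
    v.repeat_interleave(n_rep, dim=1),
    attn_mask=mask,
)

# Decomposed path: mirror the grouped views and matmuls from the diff.
scale = D ** -0.5  # SDPA's default scale; assumed to correspond to self.inv_scale
q_grouped = q.view(n_kv, n_rep, Tq, D)
k_grouped = k.view(n_kv, 1, Tk, D)
attn = (q_grouped @ k_grouped.transpose(-2, -1)) * scale       # (n_kv, n_rep, Tq, Tk)
attn = (attn.view(1, n_heads, Tq, Tk) + mask).softmax(dim=-1)  # mask broadcasts over heads
attn = attn.view(n_kv, n_rep, Tq, Tk)
y = (attn @ v.view(n_kv, 1, Tk, D)).view(1, n_heads, Tq, D)

torch.testing.assert_close(y, ref, rtol=1e-4, atol=1e-4)
print("decomposed attention matches F.scaled_dot_product_attention")
```

Folding the (assumed) batch dim of 1 into the KV-group dim is what keeps every intermediate tensor 4D, which is the workaround gated by `_DECOMPOSE_SDPA_IN_STATIC_ATTENTION_MHA` above.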
|
|