|
| 1 | +import logging |
| 2 | +from typing import Optional |
| 3 | + |
| 4 | +import torch |
| 5 | +import triton |
| 6 | +import triton.language as tl |
| 7 | + |
| 8 | +import flag_gems |
| 9 | +from flag_gems.runtime import torch_device_fn |
| 10 | +from flag_gems.utils import libentry |
| 11 | +from flag_gems.utils import triton_lang_extension as tle |
| 12 | + |
| 13 | + |
| 14 | +@triton.jit |
| 15 | +def rotary_embedding_rw_kernel( |
| 16 | + state_out, |
| 17 | + state, |
| 18 | + cos, |
| 19 | + sin, |
| 20 | + stride_state_n, |
| 21 | + stride_state_h, |
| 22 | + stride_state_d, |
| 23 | + stride_cos_n, |
| 24 | + stride_cos_d, |
| 25 | + num_tokens, |
| 26 | + num_heads, |
| 27 | + token_range, |
| 28 | + head_range, |
| 29 | + dim_range_x, |
| 30 | + dim_range_y, |
| 31 | + rotary_interleaved: tl.constexpr, |
| 32 | +): |
| 33 | + state_x_offset = ( |
| 34 | + token_range[:, None, None] * stride_state_n |
| 35 | + + head_range[None, :, None] * stride_state_h |
| 36 | + + dim_range_x[None, None, :] * stride_state_d |
| 37 | + ) |
| 38 | + state_y_offset = ( |
| 39 | + token_range[:, None, None] * stride_state_n |
| 40 | + + head_range[None, :, None] * stride_state_h |
| 41 | + + dim_range_y[None, None, :] * stride_state_d |
| 42 | + ) |
| 43 | + |
| 44 | + cos_sim_offset = ( |
| 45 | + token_range[:, None, None] * stride_cos_n |
| 46 | + + dim_range_x[None, None, :] * stride_cos_d |
| 47 | + ) |
| 48 | + if rotary_interleaved: |
| 49 | + sin_sim_offset = ( |
| 50 | + token_range[:, None, None] * stride_cos_n |
| 51 | + + dim_range_y[None, None, :] * stride_cos_d |
| 52 | + ) |
| 53 | + else: |
| 54 | + sin_sim_offset = cos_sim_offset |
| 55 | + |
| 56 | + state_x = tl.load( |
| 57 | + state + state_x_offset, |
| 58 | + mask=(token_range[:, None, None] < num_tokens) |
| 59 | + & (head_range[None, :, None] < num_heads), |
| 60 | + other=0.0, |
| 61 | + ) |
| 62 | + state_y = tl.load( |
| 63 | + state + state_y_offset, |
| 64 | + mask=(token_range[:, None, None] < num_tokens) |
| 65 | + & (head_range[None, :, None] < num_heads), |
| 66 | + other=0.0, |
| 67 | + ) |
| 68 | + |
| 69 | + cos_loaded = tl.load( |
| 70 | + cos + cos_sim_offset, |
| 71 | + mask=token_range[:, None, None] < num_tokens, |
| 72 | + other=0.0, |
| 73 | + ).to(tl.float32) |
| 74 | + sin_loaded = tl.load( |
| 75 | + sin + sin_sim_offset, |
| 76 | + mask=token_range[:, None, None] < num_tokens, |
| 77 | + other=0.0, |
| 78 | + ).to(tl.float32) |
| 79 | + |
| 80 | + out_x = state_x * cos_loaded - state_y * sin_loaded |
| 81 | + out_y = state_x * sin_loaded + state_y * cos_loaded |
| 82 | + |
| 83 | + tl.store( |
| 84 | + state_out + state_x_offset, |
| 85 | + out_x, |
| 86 | + mask=(token_range[:, None, None] < num_tokens) |
| 87 | + & (head_range[None, :, None] < num_heads), |
| 88 | + ) |
| 89 | + tl.store( |
| 90 | + state_out + state_y_offset, |
| 91 | + out_y, |
| 92 | + mask=(token_range[:, None, None] < num_tokens) |
| 93 | + & (head_range[None, :, None] < num_heads), |
| 94 | + ) |
| 95 | + |
| 96 | + |
| 97 | +@libentry() |
| 98 | +@triton.jit |
| 99 | +def rotary_embedding_siso_kernel( |
| 100 | + state_out, # [num_tokens, head_num, head_dim] |
| 101 | + state, # [num_tokens, head_num, head_dim] |
| 102 | + cos, # [num_tokens, 1, head_dim // 2] |
| 103 | + sin, # [num_tokens, 1, head_dim // 2] |
| 104 | + stride_state_n, |
| 105 | + stride_state_h, |
| 106 | + stride_state_d, |
| 107 | + stride_cos_n, |
| 108 | + stride_cos_d, |
| 109 | + num_tokens, |
| 110 | + num_heads, |
| 111 | + BLOCK_N: tl.constexpr, |
| 112 | + BLOCK_H: tl.constexpr, |
| 113 | + BLOCK_D: tl.constexpr, |
| 114 | + rotary_interleaved: tl.constexpr, |
| 115 | +): |
| 116 | + token_index = tl.program_id(0) |
| 117 | + token_range = token_index * BLOCK_N + tl.arange(0, BLOCK_N) |
| 118 | + head_index = tl.program_id(1) |
| 119 | + head_range = head_index * BLOCK_H + tl.arange(0, BLOCK_H) |
| 120 | + |
| 121 | + if rotary_interleaved: |
| 122 | + for d in range(0, BLOCK_D // 2): |
| 123 | + dim_range_x = d * 2 |
| 124 | + dim_range_y = d * 2 + 1 |
| 125 | + |
| 126 | + rotary_embedding_rw_kernel( |
| 127 | + state_out, |
| 128 | + state, |
| 129 | + cos, |
| 130 | + sin, |
| 131 | + stride_state_n, |
| 132 | + stride_state_h, |
| 133 | + stride_state_d, |
| 134 | + stride_cos_n, |
| 135 | + stride_cos_d, |
| 136 | + num_tokens, |
| 137 | + num_heads, |
| 138 | + token_range, |
| 139 | + head_range, |
| 140 | + dim_range_x, |
| 141 | + dim_range_y, |
| 142 | + rotary_interleaved, |
| 143 | + ) |
| 144 | + else: |
| 145 | + dim_range_x = tl.arange(0, BLOCK_D // 2) |
| 146 | + dim_range_y = tl.arange(BLOCK_D // 2, BLOCK_D) |
| 147 | + rotary_embedding_rw_kernel( |
| 148 | + state_out, |
| 149 | + state, |
| 150 | + cos, |
| 151 | + sin, |
| 152 | + stride_state_n, |
| 153 | + stride_state_h, |
| 154 | + stride_state_d, |
| 155 | + stride_cos_n, |
| 156 | + stride_cos_d, |
| 157 | + num_tokens, |
| 158 | + num_heads, |
| 159 | + token_range, |
| 160 | + head_range, |
| 161 | + dim_range_x, |
| 162 | + dim_range_y, |
| 163 | + rotary_interleaved, |
| 164 | + ) |
| 165 | + |
| 166 | +def apply_rotary_pos_emb( |
| 167 | + q, |
| 168 | + k, |
| 169 | + cos, |
| 170 | + sin, |
| 171 | + position_ids: Optional[torch.IntTensor] = None, |
| 172 | + rotary_interleaved: bool = False, |
| 173 | +): |
| 174 | + """ |
| 175 | + Apply rotary position embedding to q and k |
| 176 | +
|
| 177 | + Args: |
| 178 | + q: (*, q_heads, head_dim) |
| 179 | + k: (*, k_heads, head_dim) |
| 180 | + cos: (max_seq_len, head_dim // 2) |
| 181 | + sin: (max_seq_len, head_dim // 2) |
| 182 | + position_ids: (*, ), optional, position ids for each token |
| 183 | + rotary_interleaved: whether the head_dim is rotated in an interleaved way |
| 184 | +
|
| 185 | + Returns: |
| 186 | + q_embed: (*, q_heads, head_dim) |
| 187 | + k_embed: (*, k_heads, head_dim) |
| 188 | + """ |
| 189 | + logging.debug("GEMS_ASCEND ROTARY POS EMBEDDING") |
| 190 | + assert ( |
| 191 | + k.shape[-1] == q.shape[-1] |
| 192 | + ), f"q and k must have the same last dimension, got {q.shape} and {k.shape}" |
| 193 | + assert ( |
| 194 | + cos.shape[-1] == sin.shape[-1] |
| 195 | + ), f"cos and sin must have the same last dimension, got {cos.shape} and {sin.shape}" |
| 196 | + assert ( |
| 197 | + cos.shape[-1] * 2 == q.shape[-1] |
| 198 | + ), f"cos/sin dim must be half of q/k dim, got {cos.shape} and {q.shape}" |
| 199 | + assert cos.stride(-1) == 1, "cos must be contiguous at the last dimension" |
| 200 | + assert sin.stride(-1) == 1, "sin must be contiguous at the last dimension" |
| 201 | + |
| 202 | + q_shape = q.shape |
| 203 | + k_shape = k.shape |
| 204 | + |
| 205 | + assert ( |
| 206 | + q.shape[:-2] == k.shape[:-2] |
| 207 | + ), f"q and k must have the same length, got {q.shape[:-2]} and {k.shape[:-2]}" |
| 208 | + if position_ids is None: |
| 209 | + assert ( |
| 210 | + len(q.shape) == 4 |
| 211 | + ), f"q must have 4 dimensions if position_ids is not provided, got {q.shape}" |
| 212 | + seq_len = q.shape[-3] |
| 213 | + else: |
| 214 | + assert ( |
| 215 | + position_ids.shape == q.shape[:-2] |
| 216 | + ), f"position_ids must have the same length as q, got {position_ids.shape} and {q.shape[:-2]}" |
| 217 | + |
| 218 | + position_ids = position_ids.view(-1) |
| 219 | + seq_len = None |
| 220 | + |
| 221 | + q = q.view(-1, q.shape[-2], q.shape[-1]) |
| 222 | + k = k.view(-1, k.shape[-2], k.shape[-1]) |
| 223 | + |
| 224 | + q_embed = torch.empty_like(q) |
| 225 | + k_embed = torch.empty_like(k) |
| 226 | + |
| 227 | + def torch_rotary_embedding(state_out, state, cos, sin): |
| 228 | + num_tokens = state.shape[0] |
| 229 | + num_heads = state.shape[1] |
| 230 | + head_dim = state.shape[-1] |
| 231 | + |
| 232 | + BLOCK_N = 8 |
| 233 | + BLOCK_H = 4 |
| 234 | + grid = ( |
| 235 | + triton.cdiv(num_tokens, BLOCK_N), |
| 236 | + triton.cdiv(num_heads, BLOCK_H), |
| 237 | + ) |
| 238 | + with torch_device_fn.device(state_out.device): |
| 239 | + with flag_gems.use_gems(): |
| 240 | + if position_ids is None: |
| 241 | + cos = cos[: q_shape[-3], None, :] |
| 242 | + sin = sin[: q_shape[-3], None, :] |
| 243 | + else: |
| 244 | + cos = cos[position_ids, None, :] |
| 245 | + sin = sin[position_ids, None, :] |
| 246 | + |
| 247 | + if rotary_interleaved: |
| 248 | + cos = torch.repeat_interleave(cos, 2, dim=-1) |
| 249 | + sin = torch.repeat_interleave(sin, 2, dim=-1) |
| 250 | + orig_cos = cos |
| 251 | + orig_sin = sin |
| 252 | + for _ in range(q_shape[0] - 1): |
| 253 | + cos = torch.cat((cos, orig_cos), dim=0) |
| 254 | + sin = torch.cat((sin, orig_sin), dim=0) |
| 255 | + rotary_embedding_siso_kernel[grid]( |
| 256 | + state_out, |
| 257 | + state, |
| 258 | + cos, |
| 259 | + sin, |
| 260 | + state.stride(0), |
| 261 | + state.stride(1), |
| 262 | + state.stride(2), |
| 263 | + cos.stride(0), |
| 264 | + cos.stride(2), |
| 265 | + num_tokens, |
| 266 | + num_heads, |
| 267 | + BLOCK_N=BLOCK_N, |
| 268 | + BLOCK_H=BLOCK_H, |
| 269 | + BLOCK_D=head_dim, |
| 270 | + rotary_interleaved=rotary_interleaved, |
| 271 | + ) |
| 272 | + |
| 273 | + torch_rotary_embedding(q_embed, q, cos, sin) |
| 274 | + torch_rotary_embedding(k_embed, k, cos, sin) |
| 275 | + |
| 276 | + q_embed = q_embed.view(q_shape) |
| 277 | + k_embed = k_embed.view(k_shape) |
| 278 | + return q_embed, k_embed |
0 commit comments