diff --git a/problems/amd/mla/reference.py b/problems/amd/mla/reference.py new file mode 100644 index 0000000..d67b4b0 --- /dev/null +++ b/problems/amd/mla/reference.py @@ -0,0 +1,405 @@ +import torch +import torch.nn as nn +import math +import random +from task import input_t, output_t +from utils import make_match_reference +from typing import Optional, Tuple, Union + + +# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py +## We provide the implementation of the rotary embedding here, you do not need to modify this section +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(nn.Module): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. 
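+        # inv_freq[i] = base^(-2i / rotary_dim): one inverse frequency per rotated pair of dims.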
+ inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> Tuple[int, int]: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask( + low: float, high: float, dim: int, dtype: torch.dtype, device +) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. 
github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + device: Optional[str] = "cuda", + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) + / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) + * attn_factor + ) + self.device = device + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device) + / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 + - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, dtype=torch.float, device=self.device + ) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, + device=self.device, + dtype=torch.float32, + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query_rot = query[..., : self.rotary_dim] + key_rot = key[..., : self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim :] + key_pass = key[..., self.rotary_dim :] + + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device) + cos_sin = self.cos_sin_cache[ + torch.add(positions, offsets) if offsets is not None else positions + ] + # (max_seq, 64). 32 sin, 32 cos + cos, sin = cos_sin.chunk(2, dim=-1) + + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. 
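+            # Neox style pairs x[i] with x[i + rotary_dim/2], so tile the
+            # rotary_dim/2 cos/sin values across both halves of the last dim.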
+ cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key +## End of the implementation of the rotary embedding + + +def generate_input(b, d, dv, hq, sq, hkv, meansk, seed): + + cache_seqlens = torch.full((b,), meansk, dtype=torch.int32) + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = math.ceil(max_seqlen/256) * 256 + + gen = torch.Generator() + gen.manual_seed(seed) + + q = torch.randn((b, sq, hq, d), dtype=torch.bfloat16, generator=gen) + k = torch.randn((b, max_seqlen_pad, hkv, d), dtype=torch.bfloat16, generator=gen) + v = torch.randn((b, max_seqlen_pad, hkv, dv), dtype=torch.bfloat16, generator=gen) + positions = torch.tensor([sq], device=q.device).unsqueeze(0).repeat(b, 1) # only gen 1 token per req + return q, k, v, cache_seqlens, max_seqlen_pad, positions + +def scaled_dot_product_attention(query, key, value, hq, hkv, is_causal=False): + key = key.repeat_interleave(hq // hkv, dim=0) + value = value.repeat_interleave(hq // hkv, dim=0) + + scale = 1.0 / math.sqrt(query.size(-1)) + attn_weight = torch.matmul(query, key.transpose(-2, -1)) * scale + + # Apply causal mask if needed + if is_causal: + sq = query.shape[-2] + sk = key.shape[-2] + attn_bias = torch.zeros(sq, sk, dtype=torch.float32, device=query.device) + temp_mask = torch.ones(sq, sk, dtype=torch.bool, device=query.device).tril(diagonal=sk - sq) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_weight += attn_bias + + attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1) + out = torch.matmul(attn_weight, value) + + return out + +def ref_kernel(data: input_t, use_rope=False) -> output_t: + """ + q shape: batch_size, q_seqlen, hq, d + k shape: batch_size, max_seqlen_pad, hkv, d + v shape: batch_size, max_seqlen_pad, hkv, d_v + cache_seqlens: tensor containing the actual sequence lengths + max_seqlen_pad: the padded sequence length + positions: tensor containing position information for RoPE + """ + q, k, v, cache_seqlens, max_seqlen_pad, positions = data + causal = False + b, sq, hq, d = q.shape + _, _, hkv, dv = v.shape + rope_head_dim = d - dv + rotary_dim = rope_head_dim + rope_max_seq_len=16324 + rope_base=1000 + rope_scaling=1.0 + is_neox_style=True + rotary_emb = DeepseekScalingRotaryEmbedding( + rope_head_dim, + rotary_dim, + rope_max_seq_len, + rope_base, + is_neox_style, + rope_scaling, + q.dtype, + device=q.device) + out = torch.empty(b, sq, hq, dv, dtype=q.dtype) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + ik = k.view(-1, hkv, d)[begin:end] + iv = v.view(-1, hkv, dv)[begin:end] + iq = q[i] + if use_rope: + q_nope, q_pe = iq.split([dv, rotary_dim], dim=-1) # [sq, hq, d] + k_nope, k_pe = ik.split([dv, rotary_dim], dim=-1) # [sk, hkv, d] + q_pe, k_pe = rotary_emb(positions[i], q_pe, k_pe) + iq[..., dv:]=q_pe + ik[..., dv:]=k_pe + O = scaled_dot_product_attention( + iq.transpose(0, 1), + ik.transpose(0, 1), + iv.transpose(0, 1), + hq=hq, + hkv=hkv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + return out + + 
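+# Usage sketch (illustrative only; shapes follow the task.yml test cases):
+#   data = generate_input(b=128, d=576, dv=512, hq=128, sq=1, hkv=1, meansk=1024, seed=97)
+#   out = ref_kernel(data)  # -> [128, 1, 128, 512], i.e. [b, sq, hq, dv]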
+ + +check_implementation = make_match_reference(ref_kernel) \ No newline at end of file diff --git a/problems/amd/mla/submission.py b/problems/amd/mla/submission.py new file mode 100644 index 0000000..6f280bc --- /dev/null +++ b/problems/amd/mla/submission.py @@ -0,0 +1,381 @@ +import torch +import torch.nn as nn +from task import input_t, output_t +import math +from typing import Optional, Tuple, Union + +# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/refs/tags/v0.6.6.post1/vllm/model_executor/layers/rotary_embedding.py +## We provide the implementation of the rotary embedding here, you do not need to modify this section +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +class RotaryEmbedding(nn.Module): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. 
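+        # inv_freq[i] = base^(-2i / rotary_dim): one inverse frequency per rotated pair of dims.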
+ inv_freq = 1.0 / ( + base + ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim + ) + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., : self.rotary_dim] + query_pass = query[..., self.rotary_dim :] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., : self.rotary_dim] + key_pass = key[..., self.rotary_dim :] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim( + num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> float: + return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / ( + 2 * math.log(base) + ) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048, +) -> Tuple[int, int]: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings) + ) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings) + ) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask( + low: float, high: float, dim: int, dtype: torch.dtype, device +) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype, device=device) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. 
github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + device: Optional[str] = "cuda", + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) + / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) + * attn_factor + ) + self.device = device + super().__init__( + head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype + ) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base ** ( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float, device=self.device) + / self.rotary_dim + ) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range( + self.beta_fast, + self.beta_slow, + self.rotary_dim, + self.base, + self.max_position_embeddings, + ) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 + - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, dtype=torch.float, device=self.device + ) + ) * self.extrapolation_factor + inv_freq = ( + inv_freq_interpolation * (1 - inv_freq_mask) + + inv_freq_extrapolation * inv_freq_mask + ) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange( + self.max_position_embeddings * self.scaling_factor, + device=self.device, + dtype=torch.float32, + ) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * self.mscale + sin = freqs.sin() * self.mscale + cache = torch.cat((cos, sin), dim=-1) + + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query_rot = query[..., : self.rotary_dim] + key_rot = key[..., : self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim :] + key_pass = key[..., self.rotary_dim :] + + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(positions.device) + cos_sin = self.cos_sin_cache[ + torch.add(positions, offsets) if offsets is not None else positions + ] + # (max_seq, 64). 32 sin, 32 cos + cos, sin = cos_sin.chunk(2, dim=-1) + + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. 
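+            # Neox style pairs x[i] with x[i + rotary_dim/2], so tile the
+            # rotary_dim/2 cos/sin values across both halves of the last dim.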
+ cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key +## End of the implementation of the rotary embedding + +def scaled_dot_product_attention(query, key, value, hq, hkv, is_causal=False): + key = key.repeat_interleave(hq // hkv, dim=0) + value = value.repeat_interleave(hq // hkv, dim=0) + + scale = 1.0 / math.sqrt(query.size(-1)) + attn_weight = torch.matmul(query, key.transpose(-2, -1)) * scale + + # Apply causal mask if needed + if is_causal: + sq = query.shape[-2] + sk = key.shape[-2] + attn_bias = torch.zeros(sq, sk, dtype=torch.float32, device=query.device) + temp_mask = torch.ones(sq, sk, dtype=torch.bool, device=query.device).tril(diagonal=sk - sq) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_weight += attn_bias + + attn_weight = torch.nn.functional.softmax(attn_weight, dim=-1) + out = torch.matmul(attn_weight, value) + + return out + +def custom_kernel(data: input_t, use_rope=False) -> output_t: + """ + Reference implementation of mla without RoPE + Args: + data: q, k, v, cache_seqlens, max_seqlen_pad, positions + Returns: + mla output + """ + + q, k, v, cache_seqlens, max_seqlen_pad, positions = data + causal = False + b, sq, hq, d = q.shape + _, _, hkv, dv = v.shape + rope_head_dim = d - dv + rotary_dim = rope_head_dim + rope_max_seq_len=16324 + rope_base=1000 + rope_scaling=1.0 + is_neox_style=True + rotary_emb = DeepseekScalingRotaryEmbedding( + rope_head_dim, + rotary_dim, + rope_max_seq_len, + rope_base, + is_neox_style, + rope_scaling, + q.dtype, + device=q.device) + out = torch.empty(b, sq, hq, dv, dtype=q.dtype) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + ik = k.view(-1, hkv, d)[begin:end] + iv = v.view(-1, hkv, dv)[begin:end] + iq = q[i] + if use_rope: + q_nope, q_pe = iq.split([dv, rotary_dim], dim=-1) # [sq, hq, d] + k_nope, k_pe = ik.split([dv, rotary_dim], dim=-1) # [sk, hkv, d] + q_pe, k_pe = rotary_emb(positions[i], q_pe, k_pe) + iq[..., dv:]=q_pe + ik[..., dv:]=k_pe + O = scaled_dot_product_attention( + iq.transpose(0, 1), + ik.transpose(0, 1), + iv.transpose(0, 1), + hq=hq, + hkv=hkv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + return out \ No newline at end of file diff --git a/problems/amd/mla/task.py b/problems/amd/mla/task.py new file mode 100644 index 0000000..e83654f --- /dev/null +++ b/problems/amd/mla/task.py @@ -0,0 +1,16 @@ +import torch +from typing import TypeVar, TypedDict + +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, torch.Tensor]) +output_t = TypeVar("output_t", bound=torch.Tensor) + +# Define test spec with parameters in the same order as in task.yml +class TestSpec(TypedDict): + b: int # batch size + d: int # dimension + dv: int # value dimension + hq: int # number of query heads + sq: int # query sequence length + hkv: int # number of key/value heads + meansk: int # mean kv sequence length + seed: int # random seed \ No newline at end of file diff --git a/problems/amd/mla/task.yml 
b/problems/amd/mla/task.yml
new file mode 100644
index 0000000..49b93e4
--- /dev/null
+++ b/problems/amd/mla/task.yml
@@ -0,0 +1,71 @@
+# name: mla-py
+
+files:
+  - {"name": "submission.py", "source": "@SUBMISSION@"}
+  - {"name": "task.py", "source": "task.py"}
+  - {"name": "utils.py", "source": "../utils.py"}
+  - {"name": "reference.py", "source": "reference.py"}
+  - {"name": "eval.py", "source": "../eval.py"}
+
+lang: "py"
+
+description: |
+  You will implement a custom MLA decode kernel optimized for MI300. A few things are simplified here:
+
+  1. Q, K, V use the bfloat16 data type
+  2. Q, K, V hidden states are provided directly; there are no Q, K, V up/down projections
+  3. decode only, with a pre-allocated, non-paged latent KV cache
+  4. no need to update the KV cache
+  5. no need to implement RoPE in your MLA kernel; we only show its implementation in the reference kernel
+
+  The shapes of all outer and inner tensor dimensions come from DeepSeek-R1, with the number of heads split to fit on one GPU.
+  To be explicit, you will be given a tuple of tensors:
+
+  ```yaml
+  q_input [B, q_seqlen, h_q, kv_lora_rank + qk_rope_hd]
+  k_input [B, kv_seqlen, 1, kv_lora_rank + qk_rope_hd]
+  v_input [B, kv_seqlen, 1, kv_lora_rank]
+  attn_output [B, q_seqlen, h_q, kv_lora_rank]
+  ```
+
+  where
+
+  0. B:: 128 # batch size
+  1. kv_seqlen:: [1024, 6144]
+  2. q_seqlen:: 1 # only decoding is considered
+  3. qk_nope_head_dim:: 512
+  4. qk_rope_hd:: 64
+  5. kv_lora_rank (v_head_dim):: 512
+  6. h_q:: 128 # number of query heads
+  7. h_kv:: 1 # MLA uses a single KV head
+
+  The ranking criterion is the geometric mean of the benchmark results.
+
+  For the grand prize, your kernel will be evaluated against a speed-of-light analysis,
+  and the solution closest to the speed of light will be awarded the grand prize.
+
+  aiter performance for different kv_seqlen values is shown below:
+  | batch | kv_seqlen | q_seqlen | dtype | aiter time (us) |
+  |---|---|---|---|---|
+  | 128 | 1024 | 1 | bf16 | 152.52 |
+  | 128 | 6144 | 1 | bf16 | 640.57 |
+
+config:
+  main: "eval.py"
+
+templates:
+  Python: "template.py"
+
+test_timeout: 900
+benchmark_timeout: 900
+ranked_timeout: 1200
+
+tests:
+  - {"b": 128, "d": 576, "dv": 512, "hq": 128, "sq": 1, "hkv": 1, "meansk": 1024, "seed": 97}
+  - {"b": 128, "d": 576, "dv": 512, "hq": 128, "sq": 1, "hkv": 1, "meansk": 6144, "seed": 97}
+
+benchmarks:
+  - {"b": 128, "d": 576, "dv": 512, "hq": 128, "sq": 1, "hkv": 1, "meansk": 1024, "seed": 97}
+  - {"b": 128, "d": 576, "dv": 512, "hq": 128, "sq": 1, "hkv": 1, "meansk": 6144, "seed": 97}
+
+ranking_by: "geom"
\ No newline at end of file
diff --git a/problems/amd/identity/template.py b/problems/amd/mla/template.py
similarity index 100%
rename from problems/amd/identity/template.py
rename to problems/amd/mla/template.py