1 change: 1 addition & 0 deletions configs/moe_config.py
@@ -16,6 +16,7 @@ class MoEModelConfig:
v_dim: int | None = 128
batch_size: int = 24
max_steps: int = 1000
use_mem_efficient_attention: bool = False

# Training parameters
gradient_accumulation_steps: int = 4
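The new flag defaults to False, so existing configs keep the PyTorch scaled_dot_product_attention path unchanged. A minimal opt-in sketch, assuming MoEModelConfig is a dataclass (or otherwise accepts its fields as keyword arguments):

from configs.moe_config import MoEModelConfig

# Enable the xformers memory-efficient attention path; every other field keeps its default.
config = MoEModelConfig(use_mem_efficient_attention=True)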
7 changes: 3 additions & 4 deletions models/components.py
@@ -2,18 +2,17 @@
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional

from xformers.ops import SwiGLU

class Expert(nn.Module):
"""Single expert network (essentially a FeedForward layer)"""
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super().__init__()
self.linear1 = nn.Linear(d_model, d_ff, bias=False)
self.linear2 = nn.Linear(d_ff, d_model, bias=False)
self.ffn = SwiGLU(d_model, d_ff, d_model, bias=False)
self.dropout = nn.Dropout(dropout)

def forward(self, x):
return self.linear2(self.dropout(F.silu(self.linear1(x))))
return self.dropout(self.ffn(x))


class TopKRouter(nn.Module):
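For context on the Expert change above: xformers' SwiGLU is a fused, gated feed-forward unit. A rough, unfused PyTorch sketch of what SwiGLU(d_model, d_ff, d_model, bias=False) computes (the w1/w2/w3 names are illustrative, not the xformers internals):

import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUSketch(nn.Module):
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)  # gate projection
        self.w2 = nn.Linear(d_model, d_ff, bias=False)  # value projection
        self.w3 = nn.Linear(d_ff, d_model, bias=False)  # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # SiLU-gated product of the two projections, then project back to d_model.
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

Compared with the removed two-layer SiLU MLP, this adds a third projection, so each expert carries roughly 1.5x the parameters unless d_ff is scaled down to compensate.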
58 changes: 47 additions & 11 deletions models/layers.py
@@ -3,12 +3,15 @@
import torch.nn.functional as F
from torchtune.modules import RotaryPositionalEmbeddings
from .components import MixtureOfExperts
from xformers.ops import memory_efficient_attention, LowerTriangularMask


class Rotary(nn.Module):
def __init__(self, dim: int, max_seq_len: int):
super().__init__()
self.rope = RotaryPositionalEmbeddings(dim=dim, max_seq_len=max_seq_len, base=10000)
self.rope = RotaryPositionalEmbeddings(
dim=dim, max_seq_len=max_seq_len, base=10000
)

def forward(self, x_BTHD: torch.Tensor):
# x_BTHD shape: [B, T, H, D] - the layout torchtune's RotaryPositionalEmbeddings expects
@@ -18,8 +21,16 @@ def forward(self, x_BTHD: torch.Tensor):


class MultiHeadAttention(nn.Module):
def __init__(self, d_model: int, n_heads: int, max_seq_len: int, dropout: float = 0.1):
def __init__(
self,
d_model: int,
n_heads: int,
max_seq_len: int,
dropout: float = 0.1,
use_mem_atten: bool = False,
):
super().__init__()
self.mem_atten = use_mem_atten
self.d_model = d_model
self.n_heads = n_heads
self.d_k = d_model // n_heads
@@ -37,18 +48,28 @@ def forward(self, x):

qkv = self.qkv(x).reshape(batch_size, seq_len, 3, self.n_heads, self.d_k)
qkv = qkv.permute(2, 0, 3, 1, 4)
Q, K, V = qkv[0], qkv[1], qkv[2]  # [B, H, T, D]

# Q = self.rotary(Q)
# K = self.rotary(K)
# Apply RoPE on [B, T, H, D]
Q = self.rotary(Q.transpose(1, 2)).transpose(1, 2)
K = self.rotary(K.transpose(1, 2)).transpose(1, 2)

attn_output = F.scaled_dot_product_attention(
Q, K, V, is_causal=True, dropout_p=self.dropout if self.training else 0.0
if self.mem_atten:
attn_output = memory_efficient_attention(
Q, K, V, attn_bias=LowerTriangularMask(), p=self.dropout
)
else:
attn_output = F.scaled_dot_product_attention(
Q,
K,
V,
is_causal=True,
dropout_p=self.dropout if self.training else 0.0,
)
attn_output = attn_output.transpose(1, 2).reshape(
batch_size, seq_len, self.d_model
)
attn_output = attn_output.transpose(1, 2).reshape(batch_size, seq_len, self.d_model)
# attn_output = attn_output.transpose(1, 2).reshape(B, T, self.d_model)
return self.w_o(attn_output)

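One point worth double-checking in the new branch: xformers' memory_efficient_attention expects inputs laid out as [B, T, H, D] (batch, sequence, heads, head dim), whereas Q, K and V here are in the [B, H, T, D] layout used by scaled_dot_product_attention, and p is passed unconditionally rather than gated on self.training as dropout_p is in the SDPA branch. An editor's sketch of the branch with both points handled (not part of the PR as submitted):

if self.mem_atten:
    # memory_efficient_attention wants [B, T, H, D]; transpose from the SDPA layout.
    q_ = Q.transpose(1, 2)
    k_ = K.transpose(1, 2)
    v_ = V.transpose(1, 2)
    attn_output = memory_efficient_attention(
        q_, k_, v_,
        attn_bias=LowerTriangularMask(),
        p=self.dropout if self.training else 0.0,
    )
    # Output is already [B, T, H, D]; reshape directly here
    # (the shared transpose(1, 2).reshape below assumes the SDPA layout).
    attn_output = attn_output.reshape(batch_size, seq_len, self.d_model)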
@@ -64,8 +85,10 @@ def __init__(
v_dim: int,
max_seq_len: int,
dropout: float = 0.1,
use_mem_atten: bool = False,
):
super().__init__()
self.mem_atten = use_mem_atten
self.d_model = d_model
self.n_heads = n_heads
self.qk_dim = qk_rope_dim + qk_nope_dim
@@ -107,9 +130,18 @@ def forward(self, x: torch.Tensor):
k_nope, v = torch.split(kv, (self.qk_nope_dim, self.v_dim), dim=-1)
k = torch.cat([k_nope, k_rope.expand(-1, -1, self.n_heads, -1)], dim=-1)

attn_output = F.scaled_dot_product_attention(
q, k, v, is_causal=True, dropout_p=self.dropout if self.training else 0.0
)
if self.mem_atten:
attn_output = memory_efficient_attention(
q, k, v, attn_bias=LowerTriangularMask(), p=self.dropout
)
else:
attn_output = F.scaled_dot_product_attention(
q,
k,
v,
is_causal=True,
dropout_p=self.dropout if self.training else 0.0,
)
attn_output = (
attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
)
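If q, k and v at this point follow the same [B, H, T, D] convention as the SDPA call implies (the transpose(1, 2) after it suggests they do), the layout and dropout-gating caveats noted for MultiHeadAttention above apply to this branch as well.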
@@ -133,6 +165,7 @@ def __init__(
num_experts: int = 8,
top_k: int = 2,
dropout: float = 0.1,
use_mem_atten: bool = False,
):
super().__init__()

@@ -147,9 +180,12 @@ def __init__(
v_dim,
max_seq_len,
dropout,
use_mem_atten,
)
else:
self.attention = MultiHeadAttention(d_model, n_heads, max_seq_len, dropout)
self.attention = MultiHeadAttention(
d_model, n_heads, max_seq_len, dropout, use_mem_atten
)

# MoE layer
self.feed_forward = MixtureOfExperts(d_model, d_ff, num_experts, top_k, dropout)
2 changes: 1 addition & 1 deletion models/moe_llm.py
@@ -1,7 +1,6 @@
import torch
import torch.nn as nn
import math
from typing import Optional
from configs.moe_config import MoEModelConfig
from models.layers import MoETransformerBlock

@@ -33,6 +32,7 @@ def __init__(self, config: MoEModelConfig):
config.num_experts,
config.expert_top_k,
config.dropout,
config.use_mem_efficient_attention,
)
for i in range(config.n_layers)
]
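Note that the flag is forwarded positionally here, so it must stay the last parameter of MoETransformerBlock.__init__ (and of the attention constructors it is passed on to); forwarding it as use_mem_atten=config.use_mem_efficient_attention would keep the call robust to future parameter reordering.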
1 change: 1 addition & 0 deletions requirements.txt
@@ -3,6 +3,7 @@ transformers
torchtune
torchao
matplotlib
xformers
# lm-eval

# Single T4 GPU training
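xformers wheels are compiled against specific torch/CUDA builds, so it is usually worth pinning an xformers version known to match the torch version this project installs; on a mismatch, xformers may warn at import time or fail to load its compiled kernels.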