Refactoring class hierarchy for FSDP wrapping
tgale96 committed Oct 24, 2023
1 parent 52aa1b2 commit b3ab365
Showing 6 changed files with 76 additions and 32 deletions.
29 changes: 27 additions & 2 deletions megablocks/layers/dmoe.py
@@ -2,6 +2,7 @@
 from megablocks.layers import mlp
 from megablocks.layers import moe
 from megablocks.layers import mpu
+from megablocks.layers import router
 from megablocks.layers.arguments import Arguments
 import megablocks.ops as ops
 import numpy as np
@@ -13,10 +14,10 @@ def promote_scalar(x):
     return x.view(1) if not len(x.size()) else x


-class dMoE(moe.MoE):
+class ParallelDroplessMLP(moe.ParallelMLP):

     def __init__(self, args : Arguments):
-        super(dMoE, self).__init__(args)
+        super(ParallelDroplessMLP, self).__init__(args)
         self.hidden_size = args.hidden_size
         self.ffn_hidden_size = mpu.features_per_rank(args)
         self.blocking = 128
@@ -307,3 +308,27 @@ def permute_and_compute(
             bins,
             expert_capactiy,
             top_k)
+
+
+class dMoE(torch.nn.Module):
+
+    def __init__(self, args : Arguments):
+        super(dMoE, self).__init__()
+
+        # Token router.
+        self.router = router.LearnedRouter(args)
+
+        # Expert computation helper.
+        self.experts = ParallelDroplessMLP(args)
+
+    def forward(self, x):
+        # NOTE: If we're going to cast the activations to lower precision
+        # do it before we permute the tokens to save bandwidth.
+        x = common.cast_if_autocast_enabled(x)
+        sl, bs, hs = x.size()
+
+        # Compute the expert scores and assignments.
+        scores, expert_weights, top_experts = self.router(x)
+
+        # Compute the experts.
+        return self.experts(x, scores, expert_weights, top_experts)
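Because the dropless layer now reaches its expert weights through the `experts` submodule, state-dict keys for MoE/dMoE layers change from `<prefix>.mlp.w1` / `<prefix>.mlp.w2` to `<prefix>.experts.mlp.w1` / `<prefix>.experts.mlp.w2`, while `<prefix>.router.*` keys keep their names; the test updates below follow the same renaming. The helper below is a hypothetical sketch, not part of MegaBlocks, of one naive way to remap keys from a pre-refactor checkpoint, assuming `.mlp.` only occurs in keys that belong to MoE/dMoE layers.

# Hypothetical helper (not part of MegaBlocks): remap pre-refactor state-dict
# keys from '<prefix>.mlp.*' to the new '<prefix>.experts.mlp.*' layout.
# Assumes '.mlp.' only appears in keys belonging to MoE/dMoE layers.
def remap_moe_state_dict(state_dict):
    remapped = {}
    for key, value in state_dict.items():
        if '.mlp.' in key and '.experts.mlp.' not in key:
            key = key.replace('.mlp.', '.experts.mlp.')
        remapped[key] = value
    return remapped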
12 changes: 6 additions & 6 deletions megablocks/layers/dmoe_test.py
@@ -44,14 +44,14 @@ def test_modules(

     # Set the baseline parameters to match exactly.
     with torch.no_grad():
-        ne, hs, fhs = moe_mlp.mlp.w1.size()
-        w1 = dmoe_mlp.mlp.w1.view([ne, fhs, hs])
-        moe_mlp.mlp.w1.copy_(torch.transpose(w1, 1, 2).contiguous())
-        moe_mlp.mlp.w2.copy_(dmoe_mlp.mlp.w2.view([ne, fhs, hs]))
+        ne, hs, fhs = moe_mlp.experts.mlp.w1.size()
+        w1 = dmoe_mlp.experts.mlp.w1.view([ne, fhs, hs])
+        moe_mlp.experts.mlp.w1.copy_(torch.transpose(w1, 1, 2).contiguous())
+        moe_mlp.experts.mlp.w2.copy_(dmoe_mlp.experts.mlp.w2.view([ne, fhs, hs]))
         moe_mlp.router.layer.weight.copy_(dmoe_mlp.router.layer.weight)
         if moe_num_experts == 1:
-            mlp.w1.copy_(moe_mlp.mlp.w1.squeeze())
-            mlp.w2.copy_(moe_mlp.mlp.w2.squeeze())
+            mlp.w1.copy_(moe_mlp.experts.mlp.w1.squeeze())
+            mlp.w2.copy_(moe_mlp.experts.mlp.w2.squeeze())
     return args, mlp, moe_mlp, dmoe_mlp

 # min size: (1, 2, 128, 2, 1)
8 changes: 4 additions & 4 deletions megablocks/layers/memory_test.py
@@ -65,8 +65,8 @@ def test_memory(
     # Calculate weight and gradient memory usage.
     weight_memory = 2 * (
         layer.router.layer.weight.numel() +
-        layer.mlp.w1.numel() +
-        layer.mlp.w2.numel())
+        layer.experts.mlp.w1.numel() +
+        layer.experts.mlp.w2.numel())

     def grad_numel(x):
         if x.grad is not None:
@@ -75,8 +75,8 @@ def grad_numel(x):

     grad_memory = 2 * (
         grad_numel(layer.router.layer.weight) +
-        grad_numel(layer.mlp.w1) +
-        grad_numel(layer.mlp.w2))
+        grad_numel(layer.experts.mlp.w1) +
+        grad_numel(layer.experts.mlp.w2))
     weight_memory += grad_memory

     print("Weight Memory Allocated = {:0.0f}MiB".format(
43 changes: 31 additions & 12 deletions megablocks/layers/moe.py
@@ -94,10 +94,14 @@ def batched_load_balancing_loss(args : Arguments):
     return scale * torch.dot(tokens_per_expert, expert_scores)


-class MoE(torch.nn.Module):
+# NOTE: This class defines MoE expert computation, including expert model parallel
+# communication. When using FSDP on top of MegaBlocks this is the module that should
+# be wrapped s.t. the weight all-gathers can be scheduled *before* the expert model
+# parallel all2all.
+class ParallelMLP(torch.nn.Module):

     def __init__(self, args : Arguments):
-        super(MoE, self).__init__()
+        super(ParallelMLP, self).__init__()
         self.args = args

         # Calculate the number of experts in total and the number of experts
@@ -110,9 +114,6 @@ def __init__(self, args : Arguments):
         # so that we can pass it to radix sort.
         self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1)

-        # Token router.
-        self.router = router.LearnedRouter(args)
-
         # Expert MLP.
         self.mlp = mlp.MLP(args)

@@ -410,15 +411,9 @@ def parallel_forward_once(self, x, expert_weights, top_experts):
             self.args.quantize_scatter_num_bits)
         return x, tokens_per_expert.flatten()

-    def forward(self, x):
-        # NOTE: If we're going to cast the activations to lower precision
-        # do it before we permute the tokens to save bandwidth.
-        x = common.cast_if_autocast_enabled(x)
+    def forward(self, x, scores, expert_weights, top_experts):
         sl, bs, hs = x.size()

-        # Compute the expert scores and assignments.
-        scores, expert_weights, top_experts = self.router(x)
-
         # Compute the experts.
         x, tokens_per_expert = self.forward_fn(
             x, expert_weights, top_experts)
@@ -429,3 +424,27 @@ def forward(self, x):
                 return x, self.bias
             return x + self.bias
         return x
+
+
+class MoE(torch.nn.Module):
+
+    def __init__(self, args : Arguments):
+        super(MoE, self).__init__()
+
+        # Token router.
+        self.router = router.LearnedRouter(args)
+
+        # Expert computation helper.
+        self.experts = ParallelMLP(args)
+
+    def forward(self, x):
+        # NOTE: If we're going to cast the activations to lower precision
+        # do it before we permute the tokens to save bandwidth.
+        x = common.cast_if_autocast_enabled(x)
+        sl, bs, hs = x.size()
+
+        # Compute the expert scores and assignments.
+        scores, expert_weights, top_experts = self.router(x)
+
+        # Compute the experts.
+        return self.experts(x, scores, expert_weights, top_experts)
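The NOTE added to ParallelMLP above is the point of this refactor: with the router split out, ParallelMLP (reachable as the `experts` attribute of MoE and, via ParallelDroplessMLP, of dMoE) is the module to wrap when layering FSDP on top of MegaBlocks, so that its weight all-gathers can be scheduled before the expert model parallel all2all. The snippet below is a hedged sketch of one way to express that with PyTorch's module-based auto-wrap policy; it is not part of this commit, the `wrap_with_fsdp` helper name is made up, a default process group must already be initialized, and `ModuleWrapPolicy` assumes a reasonably recent PyTorch.

# Sketch only: give each expert-computation module its own FSDP unit, per the
# NOTE in moe.py. Assumes torch.distributed is initialized and that this
# PyTorch version provides ModuleWrapPolicy.
import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import ModuleWrapPolicy

from megablocks.layers.moe import ParallelMLP
from megablocks.layers.dmoe import ParallelDroplessMLP

def wrap_with_fsdp(model: torch.nn.Module) -> FSDP:
    # Wrap both the MoE and dropless expert modules (the `experts` submodules)
    # so their parameter all-gathers are issued as separate FSDP units.
    policy = ModuleWrapPolicy({ParallelMLP, ParallelDroplessMLP})
    return FSDP(model, auto_wrap_policy=policy)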
8 changes: 4 additions & 4 deletions megablocks/layers/moe_test.py
@@ -32,8 +32,8 @@ def test_modules(
     # Set the baseline parameters to match exactly.
     if moe_num_experts == 1:
         with torch.no_grad():
-            mlp.w1.copy_(moe_mlp.mlp.w1.squeeze())
-            mlp.w2.copy_(moe_mlp.mlp.w2.squeeze())
+            mlp.w1.copy_(moe_mlp.experts.mlp.w1.squeeze())
+            mlp.w2.copy_(moe_mlp.experts.mlp.w2.squeeze())
     return args, mlp, moe_mlp


@@ -126,8 +126,8 @@ def testMoE_ForwardBackwardVersusDense(self, bs, sl, hs):
         out, _ = moe_mlp(x)
         loss = out.sum()
         loss.backward()
-        w1_grad = moe_mlp.mlp.w1.grad.detach().squeeze()
-        w2_grad = moe_mlp.mlp.w2.grad.detach().squeeze()
+        w1_grad = moe_mlp.experts.mlp.w1.grad.detach().squeeze()
+        w2_grad = moe_mlp.experts.mlp.w2.grad.detach().squeeze()
         moe_mlp.zero_grad(set_to_none=True)
         x.grad = None
         moe.clear_load_balancing_loss()
8 changes: 4 additions & 4 deletions megablocks/layers/parallelism_test.py
@@ -96,16 +96,16 @@ def permute(x):
         out = x.view(hsd, esd, -1).transpose(1, 0).contiguous()
         return out.view(num_experts * ffn_hidden_size, hidden_size)

-    wp_w2_grad = gather(wp.mlp.w2.grad)
-    ep_w2_grad = permute(gather(ep.mlp.w2.grad))
+    wp_w2_grad = gather(wp.experts.mlp.w2.grad)
+    ep_w2_grad = permute(gather(ep.experts.mlp.w2.grad))
     if rank == 0:
         np.testing.assert_allclose(
             wp_w2_grad.float().cpu(),
             ep_w2_grad.float().cpu(),
             rtol=1e-5, atol=1e-5)

-    wp_w1_grad = gather(wp.mlp.w1.grad)
-    ep_w1_grad = permute(gather(ep.mlp.w1.grad))
+    wp_w1_grad = gather(wp.experts.mlp.w1.grad)
+    ep_w1_grad = permute(gather(ep.experts.mlp.w1.grad))
     if rank == 0:
         np.testing.assert_allclose(
             wp_w1_grad.float().cpu(),
