Skip to content

Commit

Permalink
change name
Browse files Browse the repository at this point in the history
  • Loading branch information
SangbumChoi committed Sep 3, 2024
1 parent 68da46a commit 72f8fcb
Showing 1 changed file with 9 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -244,23 +244,23 @@ def __init__(self, config: ViTPoseBackboneConfig) -> None:
experts = [nn.Linear(hidden_features, part_features) for _ in range(num_experts)]
self.experts = nn.ModuleList(experts)

def forward(self, x, indices):
expert_x = torch.zeros_like(x[:, :, -self.part_features :])
def forward(self, hidden_state, indices):
expert_hidden_state = torch.zeros_like(hidden_state[:, :, -self.part_features :])

x = self.fc1(x)
x = self.act(x)
shared_x = self.fc2(x)
hidden_state = self.fc1(hidden_state)
hidden_state = self.act(hidden_state)
shared_hidden_state = self.fc2(hidden_state)
indices = indices.view(-1, 1, 1)

# to support ddp training
for i in range(self.num_experts):
selectedIndex = indices == i
current_x = self.experts[i](x) * selectedIndex
expert_x = expert_x + current_x
current_hidden_state = self.experts[i](hidden_state) * selectedIndex
expert_hidden_state = expert_hidden_state + current_hidden_state

x = torch.cat([shared_x, expert_x], dim=-1)
hidden_state = torch.cat([shared_hidden_state, expert_hidden_state], dim=-1)

return x
return hidden_state


class ViTPoseBackboneMLP(nn.Module):
Expand Down

0 comments on commit 72f8fcb

Please sign in to comment.