Skip to content

Commit

Permalink
Fix whoopsie after merging upstream
Browse files Browse the repository at this point in the history
  • Loading branch information
umbertov committed Dec 23, 2020
1 parent c10799b commit 8d1452e
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions byol_pytorch/byol_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,20 +377,22 @@ def update_moving_average(self):
def forward(self, x):
image_one, image_two = self.augment1(x), self.augment2(x)

online_proj_one = self.online_encoder(image_one)
online_proj_two = self.online_encoder(image_two)
online_proj_one, _ = self.online_encoder(image_one)
online_proj_two, _ = self.online_encoder(image_two)

z1 = self.online_predictor(online_proj_one)
z2 = self.online_predictor(online_proj_two)

with torch.no_grad():
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_proj_one = target_encoder(image_one).detach()
target_proj_two = target_encoder(image_two).detach()
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_proj_one, _ = target_encoder(image_one)
target_proj_two, _ = target_encoder(image_two)
target_proj_one.detach_()
target_proj_two.detach_()

loss = raft_loss(
z1, target_proj_one,
z2, target_proj_two,
z1, target_proj_one.detach(),
z2, target_proj_two.detach(),
)

return loss.mean()
Expand Down

0 comments on commit 8d1452e

Please sign in to comment.