Skip to content

Commit

Permalink
Add code for RAFT
Browse files Browse the repository at this point in the history
Add implementation for the algorithm described by the paper
"Run Away From your Teacher: Understanding BYOL by a Novel
Self-Supervised Approach" ( https://arxiv.org/abs/2011.10944 )

The RAFT class is essentially a copy-paste of the BYOL class,
with slight changes to the `forward` method, which computes a
different loss functionm making use of the new `raft_loss`
function. RAFT's loss is the difference of two losses: an
"alignment loss" between the projection of two different
augmented views of the same image, to be minimized, and the
"cross-model loss", to be maximized, which is the distance
between the online and target representation of the same input,
averaged over the two different views.
  • Loading branch information
umbertov committed Dec 23, 2020
1 parent 3d8a22a commit c10799b
Show file tree
Hide file tree
Showing 2 changed files with 157 additions and 6 deletions.
2 changes: 1 addition & 1 deletion byol_pytorch/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from byol_pytorch.byol_pytorch import BYOL
from byol_pytorch.byol_pytorch import BYOL, RAFT
161 changes: 156 additions & 5 deletions byol_pytorch/byol_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,70 @@ def set_requires_grad(model, val):

# loss fn

def loss_fn(x, y):
def normalize_args(*args):
return [
x / F.normalize(x, dim=-1, p=2)
for x in args
]


def alignment_loss(x, y):
'''
Args:
x: predictor(online(t1(img)))
y: predictor(online(t2(img)))
Return:
The Alignment Loss as defined in the Run Away From Your Teacher Paper
'''
x, y = normalize_args(x,y)
return torch.norm(x-y)**2

def cross_model_loss(
pred_output_1, target_output_1,
pred_output_2, target_output_2):
'''
Let x1, x2 the result of applying two randomly sampled transforms
to the source image x
Args:
pred_output_1: predictor(online(x1))
target_output_1: target(x1)
pred_output_2: predictor(online(x2))
target_output_2: target(x2)
Return:
The Cross-Model Loss as defined in the Run Away From Your Teacher paper.
'''

pred_output_1, target_output_1, pred_output_2, target_output_2 = (
normalize_args(
pred_output_1, target_output_1, pred_output_2, target_output_2 )
)

term_1 = torch.norm(pred_output_1 - target_output_1)**2
term_2 = torch.norm(pred_output_2 - target_output_2)**2

return (term_1 - term_2) / 2

def raft_loss(
pred_output_1, target_output_1,
pred_output_2, target_output_2
):

pred_output_1, target_output_1, pred_output_2, target_output_2 = (
normalize_args(
pred_output_1, target_output_1, pred_output_2, target_output_2 )
)

loss_align = alignment_loss(pred_output_1, pred_output_2)
loss_cross = cross_model_loss(
pred_output_1, target_output_1,
pred_output_2, target_output_2
)

return loss_align + loss_cross


def byol_loss(x, y):
x = F.normalize(x, dim=-1, p=2)
y = F.normalize(y, dim=-1, p=2)
return 2 - 2 * (x * y).sum(dim=-1)
Expand Down Expand Up @@ -171,8 +234,7 @@ def __init__(

# default SimCLR augmentation

DEFAULT_AUG = torch.nn.Sequential(
RandomApply(
DEFAULT_AUG = torch.nn.Sequential( RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
Expand Down Expand Up @@ -240,8 +302,97 @@ def forward(self, x, return_embedding = False):
target_proj_one.detach_()
target_proj_two.detach_()

loss_one = loss_fn(online_pred_one, target_proj_two.detach())
loss_two = loss_fn(online_pred_two, target_proj_one.detach())
loss_one = byol_loss(online_pred_one, target_proj_two.detach())
loss_two = byol_loss(online_pred_two, target_proj_one.detach())

loss = loss_one + loss_two
return loss.mean()

class RAFT(nn.Module):
def __init__(
self,
net,
image_size,
hidden_layer = -2,
projection_size = 256,
projection_hidden_size = 4096,
augment_fn = None,
augment_fn2 = None,
moving_average_decay = 0.99,
use_momentum = True
):
super().__init__()

# default SimCLR augmentation

DEFAULT_AUG = torch.nn.Sequential( RandomApply(
T.ColorJitter(0.8, 0.8, 0.8, 0.2),
p = 0.3
),
T.RandomGrayscale(p=0.2),
T.RandomHorizontalFlip(),
RandomApply(
T.GaussianBlur((3, 3), (1.0, 2.0)),
p = 0.2
),
T.RandomResizedCrop((image_size, image_size)),
T.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
)

self.augment1 = default(augment_fn, DEFAULT_AUG)
self.augment2 = default(augment_fn2, self.augment1)

self.online_encoder = NetWrapper(net, projection_size, projection_hidden_size, layer=hidden_layer)

self.use_momentum = use_momentum
self.target_encoder = None
self.target_ema_updater = EMA(moving_average_decay)

self.online_predictor = MLP(projection_size, projection_size, projection_hidden_size)

# get device of network and make wrapper same device
device = get_module_device(net)
self.to(device)

# send a mock image tensor to instantiate singleton parameters
self.forward(torch.randn(2, 3, image_size, image_size, device=device))

@singleton('target_encoder')
def _get_target_encoder(self):
target_encoder = copy.deepcopy(self.online_encoder)
set_requires_grad(target_encoder, False)
return target_encoder

def reset_moving_average(self):
del self.target_encoder
self.target_encoder = None

def update_moving_average(self):
assert self.use_momentum, 'you do not need to update the moving average, since you have turned off momentum for the target encoder'
assert self.target_encoder is not None, 'target encoder has not been created yet'
update_moving_average(self.target_ema_updater, self.target_encoder, self.online_encoder)

def forward(self, x):
image_one, image_two = self.augment1(x), self.augment2(x)

online_proj_one = self.online_encoder(image_one)
online_proj_two = self.online_encoder(image_two)

z1 = self.online_predictor(online_proj_one)
z2 = self.online_predictor(online_proj_two)

with torch.no_grad():
target_encoder = self._get_target_encoder() if self.use_momentum else self.online_encoder
target_proj_one = target_encoder(image_one).detach()
target_proj_two = target_encoder(image_two).detach()

loss = raft_loss(
z1, target_proj_one,
z2, target_proj_two,
)

return loss.mean()


0 comments on commit c10799b

Please sign in to comment.