-
Notifications
You must be signed in to change notification settings - Fork 1
/
pretrain_moco.py
153 lines (120 loc) · 6.89 KB
/
pretrain_moco.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""
part of the code adopted from https://github.com/lightly-ai/lightly
"""
import torch.nn as nn
import torch.utils
import torch
import copy
from models.ssl_utils import NTXentLoss, projection_mlp
from pretrain_template import ClientTemplate
def get_clients(encoder_local_top_bottom_list, encoder_local_top_list, encoder_cross_list, args):
client_list = []
for i in range(args.k):
client = SSClient(i, [encoder_local_top_bottom_list[i], encoder_local_top_list[i], encoder_cross_list[i]], args)
client_list.append(client)
return client_list
class SSClient(ClientTemplate):
def __init__(self, client_idx, models, args):
super(SSClient, self).__init__(client_idx, models, args)
# methods related models and optimizers
self.projection_mlp_cross = projection_mlp(args.num_ftrs, self.proj_hidden_dim, self.out_dim,
self.num_mlp_layers).to(args.device)
self.projection_mlp_local = projection_mlp(args.num_ftrs, self.proj_hidden_dim, self.out_dim,
self.num_mlp_layers).to(args.device)
# loss
self.cross_criterion = NTXentLoss().to(args.device)
self.local_criterion_one = NTXentLoss(memory_bank_size=4096).to(args.device)
self.local_criterion_two = NTXentLoss(memory_bank_size=4096).to(args.device)
self.local_criterion_constraint = NTXentLoss().to(args.device)
# setting for aggregation
self.model_local_top = nn.ModuleList(
[self.encoder_local_top, self.projection_mlp_local])
self.models = nn.ModuleList([self.encoder_cross, self.projection_mlp_cross,
self.encoder_local_bottom, self.encoder_local_top, self.projection_mlp_local])
# byol
self.encoder_local_bottom_momentum = copy.deepcopy(self.encoder_local_bottom)
self.encoder_local_top_momentum = copy.deepcopy(self.encoder_local_top)
self.projection_head_local_momentum = copy.deepcopy(self.projection_mlp_local)
deactivate_requires_grad(self.encoder_local_bottom_momentum)
deactivate_requires_grad(self.encoder_local_top_momentum)
deactivate_requires_grad(self.projection_head_local_momentum)
# rescale learning rate
self.pretrain_lr_ratio = 0.5/(self.args.local_ratio + self.args.constraint_ratio)
# get optimizer
self.optimizer_list_cross = [self.get_optimizer(self.encoder_cross, self.args.optimizer),
self.get_optimizer(self.projection_mlp_cross)]
self.optimizer_list_local = [self.get_optimizer(self.encoder_local_bottom, self.args.optimizer),
self.get_optimizer(self.encoder_local_top, self.args.optimizer),
self.get_optimizer(self.projection_mlp_local)]
# learning rate scheduler
self.scheduler_list = [
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
for optimizer in self.optimizer_list_cross[:-1] if optimizer is not None] +[
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs)
for optimizer in self.optimizer_list_local[:-1] if optimizer is not None]
self.model_to_device(args.device)
def get_local_projection_momentum_feature(self, x):
if isinstance(x, list):
f_local = self.encoder_local_bottom_momentum(x[0].float().to(self.args.device))
h_local = self.encoder_local_top_momentum(f_local).flatten(start_dim=1)
z_local = self.projection_head_local_momentum(h_local)
else:
f_local = self.encoder_local_bottom_momentum(x)
h_local = self.encoder_local_top_momentum(f_local).flatten(start_dim=1)
z_local = self.projection_head_local_momentum(h_local)
z_local = z_local.detach()
return z_local
def compute_cross_loss(self, x, y, z_cross_own, z_cross_received, epoch):
meta = {}
loss_debug = []
if isinstance(z_cross_received, list):
loss = torch.tensor(0)
for z_item in z_cross_received:
loss_ind = self.cross_criterion(z_cross_own, z_item.detach())
loss_debug.append(loss_ind.item())
loss = loss + loss_ind
loss = loss / len(z_cross_received)
else:
loss = self.cross_criterion(z_cross_own, z_cross_received.detach())
meta['loss_debug'] = loss_debug
return loss, loss_debug
def compute_local_loss(self, x, y, epoch=0):
meta = {}
loss_constraint = torch.tensor(0).to(self.args.device)
update_momentum(self.encoder_local_bottom, self.encoder_local_bottom_momentum, m=0.999)
update_momentum(self.encoder_local_top, self.encoder_local_top_momentum, m=0.999)
update_momentum(self.projection_mlp_local, self.projection_head_local_momentum, m=0.999)
# two augemented views
x1 = x[0].float().to(self.args.device)
x2 = x[1].float().to(self.args.device)
# model outputs of two augemented views
f1 = self.encoder_local_bottom(x1)
h1 = self.encoder_local_top(f1).flatten(start_dim=1)
z1 = self.projection_mlp_local(h1)
f2 = self.encoder_local_bottom(x2)
h2 = self.encoder_local_top(f2).flatten(start_dim=1)
z2 = self.projection_mlp_local(h2)
z1_momentum = self.get_local_projection_momentum_feature(x1)
z2_momentum = self.get_local_projection_momentum_feature(x2)
# local contrastive loss using augmented views
loss = self.args.local_ratio * (self.local_criterion_one(z1, z2_momentum.detach()) +
self.local_criterion_two(z2, z1_momentum.detach()))
# regularization loss of cross encoder
if self.args.constraint_ratio > 0:
h1_cross = self.encoder_cross(x1).flatten(start_dim=1)
h2_cross = self.encoder_cross(x2).flatten(start_dim=1)
z1_cross = self.projection_mlp_cross(h1_cross)
z2_cross = self.projection_mlp_cross(h2_cross)
loss_constraint = self.args.constraint_ratio * (self.local_criterion_one(z1, z1_cross.detach()) +
self.local_criterion_two(z2, z2_cross.detach()))
loss = loss + loss_constraint
meta['loss_debug'] = loss_constraint.item()
return loss, meta
def deactivate_requires_grad(model: nn.Module):
"""Deactivates the requires_grad flag for all parameters of a model."""
for param in model.parameters():
param.requires_grad = False
def update_momentum(model: nn.Module, model_ema: nn.Module, m: float):
"""Updates parameters of `model_ema` with Exponential Moving Average of `model`"""
for model_ema, model in zip(model_ema.parameters(), model.parameters()):
model_ema.data = model_ema.data * m + model.data * (1. - m)