-
Notifications
You must be signed in to change notification settings - Fork 19
/
pyhessian.py
224 lines (181 loc) · 7.44 KB
/
pyhessian.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# Credit to: https://github.com/amirgholami/PyHessian
import torch
import math
from torch.autograd import Variable
import numpy as np
def group_product(xs, ys):
    """
    Inner product of two lists of tensors, treated as one long vector.

    :param xs: list of tensors
    :param ys: list of tensors paired element-wise with ``xs``
        (pairing stops at the shorter list)
    :return: sum over i of ``torch.sum(xs[i] * ys[i])`` as a scalar tensor;
        the int ``0`` when there are no pairs
    """
    total = 0
    for x_part, y_part in zip(xs, ys):
        total = total + torch.sum(x_part * y_part)
    return total
def group_add(params, update, alpha=1):
    """
    In-place update ``params[i] += update[i] * alpha`` for every index i.

    The addition is applied through ``.data`` so it is not recorded by autograd.

    :param params: list of tensors, modified in place
    :param update: list of tensors indexed in parallel with ``params``
    :param alpha: scalar (or scalar tensor) multiplier applied to each update
    :return: the same ``params`` list object
    """
    for idx, target in enumerate(params):
        scaled = update[idx] * alpha
        target.data.add_(scaled)
    return params
def normalization(v):
    """
    Return ``v`` rescaled to unit L2 norm.

    The norm is taken jointly over all tensors in the list (as if they were
    concatenated into one vector). 1e-6 is added to the denominator to avoid
    division by zero. New tensors are returned; ``v`` itself is not modified.

    :param v: list of tensors
    :return: list of tensors, each divided by (global norm + 1e-6)
    """
    norm = (group_product(v, v) ** 0.5).cpu().item()
    denom = norm + 1e-6
    return [vi / denom for vi in v]
def get_params_grad(model):
    """
    Collect the trainable parameters of ``model`` and their current gradients.

    Parameters with ``requires_grad=False`` are skipped. For a parameter
    with no gradient the float ``0.`` is stored. Otherwise ``param.grad + 0.``
    is stored: a new tensor that stays attached to the autograd graph when
    the gradient has one.

    :param model: a ``torch.nn.Module``
    :return: ``(params, grads)`` as two parallel lists
    """
    trainable = [param for param in model.parameters() if param.requires_grad]
    grads = [0. if param.grad is None else param.grad + 0. for param in trainable]
    return trainable, grads
def hessian_vector_product(gradsH, params, v):
    """
    Compute the Hessian-vector product Hv by differentiating <gradsH, v> w.r.t. ``params``.

    :param gradsH: gradients of the loss w.r.t. ``params``; they must carry an
        autograd graph (e.g. produced by ``backward(create_graph=True)``)
    :param params: the tensors the gradients were taken with respect to
    :param v: list of tensors with the same shapes as ``params``
    :return: tuple of tensors, one per entry of ``params``

    The graph is retained so the product can be evaluated again with another ``v``.
    """
    return torch.autograd.grad(
        gradsH,
        params,
        grad_outputs=v,
        only_inputs=True,
        retain_graph=True,
    )
def orthnormal(w, v_list):
    """
    Remove from ``w`` its component along each vector in ``v_list``, then normalize.

    Projections are subtracted one vector at a time, each computed against the
    already-updated ``w``. A single pass fully orthogonalizes ``w`` only if the
    vectors in ``v_list`` are orthonormal; presumably callers pass normalized
    eigenvector estimates (TODO(review): confirm).

    Note: the tensors inside ``w`` are modified in place, because ``group_add``
    writes through ``.data``. The returned list holds new, normalized tensors.

    :param w: list of tensors representing one vector
    :param v_list: list of vectors, each a list of tensors shaped like ``w``
    :return: normalized list of tensors
    """
    for basis in v_list:
        projection = group_product(w, basis)
        w = group_add(w, basis, alpha=-projection)
    return normalization(w)
class hessian:
    """
    Hessian information of a PDE-residual loss w.r.t. a model's trainable parameters.

    The loss is defined by ``criterion``: the mean squared residual of
    u_xx + u_tt + u + (pi^2 + (4 pi)^2 - 1) sin(pi x) sin(4 pi t).

    Currently implemented:
      i) the top_n eigenvalues / eigenvectors of the Hessian, via power iteration
         (``eigenvalues``).

    The Hessian is either taken on a single batch (``data``) or averaged over all
    batches of ``dataloader``, weighted by batch size.
    """

    def __init__(self, model, data=None, dataloader=None, cuda=False):
        """
        model: network that needs Hessian information; called as ``model(x, t)``
            and switched to evaluation mode
        data: a single batch ``(x, t)`` of input coordinates
        dataloader: an iterable yielding ``(x, t)`` batches
        cuda: if True, inputs are moved to ``'cuda'``; otherwise the device is ``'cpu'``

        Exactly one of ``data`` / ``dataloader`` must be given. In the single-batch
        case the loss gradient is computed once here (with ``create_graph=True``)
        and reused for every Hessian-vector product.
        """
        # make sure we either pass a single batch or a dataloader
        assert (data is not None) != (dataloader is not None), \
            "pass exactly one of `data` or `dataloader`"
        self.model = model.eval()  # put the model in evaluation mode
        if data is not None:
            self.data = data
            self.full_dataset = False
        else:
            self.data = dataloader
            self.full_dataset = True
        self.device = 'cuda' if cuda else 'cpu'

        # pre-processing for the single batch case to simplify the computation.
        if not self.full_dataset:
            self.x, self.t = self.data
            if self.device == 'cuda':
                self.x, self.t = self.x.cuda(), self.t.cuda()
            # criterion differentiates the output w.r.t. x and t, so both must be in the graph.
            self.x = self._ensure_requires_grad(self.x)
            self.t = self._ensure_requires_grad(self.t)
            # Clear stale gradients first: backward() accumulates into .grad, and any
            # leftover graph-carrying gradient would be added into gradsH and corrupt Hv.
            self.model.zero_grad()
            # For a single batch the gradients can be computed once and re-used.
            outputs = self.model(self.x, self.t)
            loss = self.criterion(self.x, self.t, outputs)
            loss.backward(create_graph=True)

        # Extract the trainable parameters in BOTH modes: eigenvalues() and
        # dataloader_hv_product() need self.params to build vectors of matching shape.
        # gradsH is only used in the single-batch mode.
        params, gradsH = get_params_grad(self.model)
        self.params = params
        self.gradsH = gradsH  # gradient used for Hessian computation

    @staticmethod
    def _ensure_requires_grad(tensor):
        """Return ``tensor`` if it requires grad, else a detached alias (shared storage) that does."""
        if tensor.requires_grad:
            return tensor
        return tensor.detach().requires_grad_(True)

    def criterion(self, x, t, u):
        """
        Mean squared residual of u_xx + u_tt + u + (pi^2 + (4 pi)^2 - 1) sin(pi x) sin(4 pi t).

        The residual vanishes for u = sin(pi x) sin(4 pi t). Derivatives are taken
        with ``create_graph=True`` so the loss can be differentiated further
        w.r.t. the model parameters.

        x, t: input coordinates; must be part of the graph that produced ``u``
        u: model output ``model(x, t)``
        """
        u_x = torch.autograd.grad(u, x, grad_outputs=torch.ones_like(u), retain_graph=True, create_graph=True)[0]
        u_xx = torch.autograd.grad(u_x, x, grad_outputs=torch.ones_like(u_x), retain_graph=True, create_graph=True)[0]
        u_t = torch.autograd.grad(u, t, grad_outputs=torch.ones_like(u), retain_graph=True, create_graph=True)[0]
        u_tt = torch.autograd.grad(u_t, t, grad_outputs=torch.ones_like(u_t), retain_graph=True, create_graph=True)[0]
        pi = torch.tensor(np.pi)
        loss_res = torch.mean((
            u_xx + u_tt + u + (pi**2 + (pi * 4)**2 - 1)
            * torch.sin(pi * x) * torch.sin(pi * 4 * t))**2)
        return loss_res

    def dataloader_hv_product(self, v):
        """
        Hessian-vector product averaged over every batch of the dataloader.

        Each batch's Hv is weighted by its batch size (``x.size(0)``) and the sum
        is divided by the total number of samples.

        v: list of tensors shaped like ``self.params``
        Returns ``(eigenvalue, THv)``: the Rayleigh-quotient estimate <THv, v> as a
        float, and the averaged product as a list of tensors.
        Raises ValueError if the dataloader yields no samples.
        """
        device = self.device
        num_data = 0  # count the number of datum points in the dataloader
        THv = [torch.zeros(p.size()).to(device) for p in self.params]  # accumulate result
        for x, t in self.data:
            self.model.zero_grad()
            tmp_num_data = x.size(0)
            # Move once and reuse the SAME tensors for the forward pass and the
            # criterion, so the input derivatives are taken on the graph that
            # produced `outputs` and everything lives on one device.
            x = self._ensure_requires_grad(x.to(device))
            t = self._ensure_requires_grad(t.to(device))
            outputs = self.model(x, t)
            loss = self.criterion(x, t, outputs)
            loss.backward(create_graph=True)
            params, gradsH = get_params_grad(self.model)
            self.model.zero_grad()
            Hv = torch.autograd.grad(gradsH,
                                     params,
                                     grad_outputs=v,
                                     only_inputs=True,
                                     retain_graph=False)
            THv = [
                THv1 + Hv1 * float(tmp_num_data) + 0.
                for THv1, Hv1 in zip(THv, Hv)
            ]
            num_data += float(tmp_num_data)
        if num_data == 0:
            # Without this guard the division below would silently yield NaN.
            raise ValueError("dataloader yielded no samples")
        THv = [THv1 / float(num_data) for THv1 in THv]
        eigenvalue = group_product(THv, v).cpu().item()
        return eigenvalue, THv

    def eigenvalues(self, maxIter=100, tol=1e-3, top_n=1):
        """
        Compute the top_n eigenvalues using the power iteration method.

        Each new eigenvector is kept orthogonal to the previously found ones, so
        successive runs target the next dominant (largest-magnitude) eigenvalue.

        maxIter: maximum iterations used to compute each single eigenvalue
        tol: the relative tolerance between two consecutive eigenvalue estimates
        top_n: number of eigenvalues to compute
        Returns ``(eigenvalues, eigenvectors)``: a list of floats and a list of
        eigenvectors (each a list of tensors shaped like ``self.params``).
        """
        assert top_n >= 1
        device = self.device
        eigenvalues = []
        eigenvectors = []
        for _ in range(top_n):
            eigenvalue = None
            v = [torch.randn(p.size()).to(device) for p in self.params]  # random start vector
            v = normalization(v)
            for _ in range(maxIter):
                v = orthnormal(v, eigenvectors)
                self.model.zero_grad()
                if self.full_dataset:
                    tmp_eigenvalue, Hv = self.dataloader_hv_product(v)
                else:
                    Hv = hessian_vector_product(self.gradsH, self.params, v)
                    tmp_eigenvalue = group_product(Hv, v).cpu().item()
                v = normalization(Hv)
                if eigenvalue is None:
                    eigenvalue = tmp_eigenvalue
                elif abs(eigenvalue - tmp_eigenvalue) / (abs(eigenvalue) + 1e-6) < tol:
                    # converged: keep the previous estimate, as before
                    break
                else:
                    eigenvalue = tmp_eigenvalue
            eigenvalues.append(eigenvalue)
            eigenvectors.append(v)
        return eigenvalues, eigenvectors