-
Notifications
You must be signed in to change notification settings - Fork 2
/
basic.py
36 lines (26 loc) · 1022 Bytes
/
basic.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# common things.
import torch
# Fix every random-number generator used here so that model initialisation
# (and any other sampling) is reproducible from run to run.
_SEED = 42
torch.manual_seed(_SEED)
torch.cuda.manual_seed(_SEED)  # current CUDA device
torch.cuda.manual_seed_all(_SEED)  # every visible CUDA device
class DummyModel(torch.nn.Module):
    """Small two-layer MLP used as a deterministic test fixture.

    Architecture: ``Linear(in_features, hidden_features) -> ReLU ->
    Linear(hidden_features, out_features)``.  The defaults (10 -> 173 -> 1)
    are the previously hard-coded sizes, so ``DummyModel()`` builds the same
    module, with the same state-dict keys, as before.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer.
        out_features: Size of the last dimension of the output.
    """

    def __init__(
        self,
        in_features: int = 10,
        hidden_features: int = 173,
        out_features: int = 1,
    ) -> None:
        super().__init__()
        # Keep fc1 constructed before fc2: layer initialisation draws from the
        # global RNG in construction order, so reordering would change the
        # seeded initial weights, and the "fc1.*"/"fc2.*" state-dict keys are
        # what a saved reference checkpoint is compared against.
        self.fc1 = torch.nn.Linear(in_features, hidden_features)
        self.fc2 = torch.nn.Linear(hidden_features, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``fc1``, a ReLU, then ``fc2`` to ``x``.

        ``x`` must have ``in_features`` as its last dimension; any leading
        dimensions are carried through by ``torch.nn.Linear``.
        """
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)
@torch.no_grad()
def check_model_from_reference(
    model: torch.nn.Module,
    *,
    ref_path: str = "model_ref.pth",
    atol: float = 1e-2,
    rtol: float = 1e-5,
) -> None:
    """Assert that ``model``'s state dict matches a saved reference state dict.

    Loads a state dict from ``ref_path`` onto the CPU and, for every key in
    it, compares the reference tensor with the model's tensor of the same
    name (both cast to float32 on the CPU) using ``torch.allclose``.  Keys
    present in the model but absent from the reference are not checked.
    Prints a confirmation line when every reference key matches.

    Args:
        model: Module whose ``state_dict()`` is checked; its tensors may live
            on any device (they are moved to the CPU for comparison).
        ref_path: Path of the reference checkpoint.  A relative path is
            resolved against the current working directory.
        atol: Absolute tolerance passed to ``torch.allclose``.
        rtol: Relative tolerance passed to ``torch.allclose``; the default
            equals ``torch.allclose``'s own default.

    Raises:
        AssertionError: If a reference key is missing from the model, a
            tensor's shape differs from the reference, or values differ
            beyond tolerance.  Raised explicitly (not via ``assert``) so the
            check is not stripped under ``python -O``.
    """
    # NOTE(review): torch.load unpickles the file; only point this at trusted
    # checkpoints (newer PyTorch versions support weights_only=True).
    ref_state_dict = torch.load(ref_path, map_location="cpu")
    model_state_dict = model.state_dict()
    for k, ref_tensor in ref_state_dict.items():
        if k not in model_state_dict:
            raise AssertionError(
                f"Key {k} from the reference state dict is missing from the model state dict"
            )
        ref = ref_tensor.float()
        current = model_state_dict[k].cpu().float()
        # torch.allclose broadcasts its operands, so without this check a
        # shape mismatch such as (2, 3) vs (1, 3) could pass silently.
        if ref.shape != current.shape:
            raise AssertionError(
                f"Shape mismatch for key {k}: reference {tuple(ref.shape)} "
                f"vs model {tuple(current.shape)}"
            )
        if not torch.allclose(ref, current, atol=atol, rtol=rtol):
            raise AssertionError(
                f"Model state dict does not match the reference model state dict for key {k}. "
                f"Difference: {(ref - current).abs().max()}"
            )
    print("Model state dict matches the reference model state dict")