-
Notifications
You must be signed in to change notification settings - Fork 0
/
LTH_fig4a_oneshot_random.py
65 lines (48 loc) · 2.75 KB
/
LTH_fig4a_oneshot_random.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
from torch import nn
import torchvision
import matplotlib.pyplot as plt
from d2l import torch as d2l
from dl_assignment_7_common import * # Your functions should go here if you want to use them from scripts
import os
import numpy as np
import torch.nn.utils.prune as prune
print('start')

# All artefacts of this experiment are grouped under the run name:
#   cp_path     - checkpoint directory handed to train() and read by torch.load
#   result_path - directory receiving the early-stop tables (np.savetxt)
run_name = 'oneshot_random'
cp_path = f'./checkpoints/{run_name}'
result_path = f'./results/{run_name}'

# exist_ok=True replaces the exists()/makedirs() pair: a single call that does
# not fail if the directory is created between a check and the creation.
os.makedirs(cp_path, exist_ok=True)
os.makedirs(result_path, exist_ok=True)
# Compute device as chosen by d2l.
# NOTE(review): `device` is not referenced anywhere else in this script;
# confirm whether create_network()/train() are expected to receive it.
device = d2l.try_gpu()

# Dataset object from the project's get_dataset helper: 'mnist' read from
# ./data with mini-batches of 60 and shuffling on. download=False, so the data
# presumably has to be present in ./data already. The returned object is
# passed unchanged to every train() call.
dataset_used = get_dataset(
    'mnist',
    dir='./data',
    batch_size=60,
    shuffle=True,
    download=False,
)
# Experiment grid.
#   runs                  - number of repetitions per sparsity level
#   remaining_percentages - share of weights (in %) that survives pruning
#   epochs[i]             - training budget for the i-th sparsity level
#                           (sparser levels are given more epochs)
runs = 3
remaining_percentages = np.array([100, 70, 50, 15, 10, 7, 4, 3.5, 2, 1, 0.5])
epochs = [16, 16, 16, 16, 16, 18, 24, 24, 36, 40, 46]

# The two lists are indexed together (epochs[i] for the i-th fraction); fail
# fast here instead of with an IndexError (or silently unused entries) later.
if len(epochs) != len(remaining_percentages):
    raise ValueError(
        f'epochs has {len(epochs)} entries but remaining_percentages has '
        f'{len(remaining_percentages)}; they must be aligned one-to-one'
    )

# Fraction of weights to remove, e.g. 70 % remaining -> prune 0.3.
prune_fractions = (100 - remaining_percentages) / 100


def _make_result_table(fractions, n_runs):
    """Return a (len(fractions), n_runs + 1) float table.

    Column 0 holds the prune fraction of each row so the saved text file is
    self-describing; columns 1..n_runs are zero-initialised, one per run.
    """
    return np.insert(np.zeros([len(fractions), n_runs]), 0, fractions, axis=1)


# One independent table per early-stopping statistic, filled per (fraction, run).
early_stop_iterations = _make_result_table(prune_fractions, runs)
early_stop_trainacc = _make_result_table(prune_fractions, runs)
early_stop_testacc = _make_result_table(prune_fractions, runs)
# Dense baseline: build a LeNet (784 inputs -> 10 outputs) and train it for
# 10 epochs with checkpoints going to cp_path. The pruning loop reloads
# "model_LeNet-after-LTH_fig4a_base.pth" from cp_path, presumably the file
# written by this call (naming scheme lives in dl_assignment_7_common; verify).
# The early-stop statistics returned here are overwritten before any use.
net, optimizer = create_network(arch='LeNet', input=784, output=10)
_, early_stop_values = train(
    net,
    optimizer,
    dataset_used,
    epochs=10,
    file_specifier='LTH_fig4a_base',
    val_interval=2,
    cp_path=cp_path,
    plot=False,
)
# Load the trained dense baseline once. Every (run, fraction) pair derives its
# mask from these same weights, so re-reading the checkpoint from disk on each
# iteration is unnecessary. load_state_dict() copies the tensors into the
# target network, so sharing this state dict across iterations is safe.
# NOTE(review): .state_dict() is called on the loaded object, so the file
# appears to hold a whole pickled nn.Module. PyTorch >= 2.6 defaults
# torch.load to weights_only=True, which rejects such files; if loading fails,
# pass weights_only=False here (only for trusted checkpoints).
trained_state = torch.load(
    f"{cp_path}/model_LeNet-after-LTH_fig4a_base.pth"
).state_dict()

for j in range(runs):
    for i, fraction in enumerate(prune_fractions):
        print(f'run {j} for fraction {fraction}')

        # Rebuild the trained network and derive the L1 pruning mask from it.
        # A fresh copy is made per iteration in case L1_prune modifies the
        # network it receives (assumption; see dl_assignment_7_common).
        source_net, _ = create_network(arch='LeNet', input=784, output=10)
        source_net.load_state_dict(trained_state)
        mask = L1_prune(source_net, fraction)

        # "Random" variant: the mask is applied to a newly created network
        # (presumably randomly initialised by create_network), not to the
        # trained weights.
        net, optimizer = create_network(arch='LeNet', input=784, output=10)
        net_pruned = prune_using_mask(net, mask)
        optimizer = torch.optim.Adam(net_pruned.parameters(), lr=0.0012)

        # NOTE(review): file_specifier does not include the run index j, so
        # each run's checkpoints overwrite the previous run's for this fraction.
        _, early_stop_values = train(
            net_pruned,
            optimizer,
            dataset_used,
            epochs=epochs[i],
            file_specifier=f'LTH_4a_L1_pruned{fraction}',
            val_interval=2,
            cp_path=cp_path,
            plot=False,
        )

        # Column 0 of each table holds the fraction, so run j lives in column j+1.
        early_stop_iterations[i, j + 1] = early_stop_values['iteration']
        early_stop_trainacc[i, j + 1] = early_stop_values['train_acc']
        early_stop_testacc[i, j + 1] = early_stop_values['test_acc']

    # Persist after every completed run so finished runs are not lost if a
    # later (long) run is interrupted; the final files are the same as before.
    np.savetxt(f'{result_path}/early_stop_iterations_{run_name}.txt', early_stop_iterations)
    np.savetxt(f'{result_path}/early_stop_trainacc_{run_name}.txt', early_stop_trainacc)
    np.savetxt(f'{result_path}/early_stop_testacc_{run_name}.txt', early_stop_testacc)