-
Notifications
You must be signed in to change notification settings - Fork 0
/
compute_stats_experiments.py
57 lines (48 loc) · 1.85 KB
/
compute_stats_experiments.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import glob
import os
import pickle

import cv2
import numpy as np
from tqdm import tqdm


def compute_mean_std(paths, mean=None, std=None, *, nb_channels=6):
    """Compute per-channel pixel mean and standard deviation over images.

    Each path is read as a single-channel grayscale image and scaled by 1/255.
    The channel is taken from the file's base name: the third
    ``_``-separated field, whose second character is a 1-based channel number
    (e.g. ``w3`` in a name like ``X_Y_w3.jpeg`` maps to channel index 2).

    If both ``mean`` and ``std`` are given, every image is first standardized
    with its own channel's statistics before accumulation. When those are the
    statistics of the same images, the result should be close to 0 and 1.

    Args:
        paths: Iterable of image file paths.
        mean: Optional per-channel means used to standardize images.
        std: Optional per-channel standard deviations used to standardize images.
        nb_channels: Number of channels to accumulate (default 6).

    Returns:
        ``(mean, std)`` as float arrays of length ``nb_channels``. Channels for
        which no image was seen are ``nan``.

    Raises:
        FileNotFoundError: if OpenCV cannot read an image.
        ValueError: if the channel parsed from a file name is out of range.
    """
    standardize = (mean is not None) and (std is not None)
    count = np.zeros(nb_channels)
    sum_x = np.zeros(nb_channels)
    sum_x2 = np.zeros(nb_channels)
    for path in tqdm(paths, desc='Imgs'):
        # Parse from the base name only, so underscores in directory names
        # cannot shift the field index.
        channel = int(os.path.basename(path).split('_')[2][1]) - 1
        if not 0 <= channel < nb_channels:
            # Without this check a channel number of 0 would silently wrap to
            # the last channel via negative indexing.
            raise ValueError(f'channel {channel + 1} out of range in {path!r}')
        im = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
        if im is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f'could not read image {path!r}')
        im = im / 255
        if standardize:
            im = (im - mean[channel]) / std[channel]
        # Count the real number of pixels instead of assuming 512x512 images.
        count[channel] += im.size
        sum_x[channel] += np.sum(im)
        sum_x2[channel] += np.sum(im ** 2)
    # Channels that received no images stay nan, without 0/0 RuntimeWarnings.
    seen = count > 0
    mean = np.full(nb_channels, np.nan)
    np.divide(sum_x, count, out=mean, where=seen)
    mean_x2 = np.full(nb_channels, np.nan)
    np.divide(sum_x2, count, out=mean_x2, where=seen)
    # E[x^2] - E[x]^2 can dip slightly below zero through rounding; clip
    # before the square root so a near-constant channel gives 0, not nan.
    # np.maximum propagates nan, so unseen channels remain nan.
    std = np.sqrt(np.maximum(mean_x2 - mean ** 2, 0))
    return mean, std


FILENAME = 'stats_experiments.pickle'

# Experiment names are the immediate sub-directories of data/train and
# data/test. normpath + basename drops the trailing separator portably;
# splitting on '/' breaks on Windows, where glob returns backslashes.
experiments_train = [os.path.basename(os.path.normpath(exp))
                     for exp in sorted(glob.glob('data/train/*/'))]
experiments_test = [os.path.basename(os.path.normpath(exp))
                    for exp in sorted(glob.glob('data/test/*/'))]
# The image glob below spans every split, so an experiment present in both
# train and test must only be processed once; dict.fromkeys dedupes while
# keeping first-seen order.
experiments = list(dict.fromkeys(experiments_train + experiments_test))


def _experiment_paths(experiment):
    """Return the sorted JPEG paths of *experiment* across all data splits."""
    # glob.escape guards against glob metacharacters in directory names;
    # sorting makes the accumulation order (and thus float sums) reproducible.
    pattern = os.path.join('data', '*', glob.escape(experiment), '*', '*.jpeg')
    return sorted(glob.glob(pattern))


stats_experiments = dict()
for experiment in tqdm(experiments, desc='experiments'):
    mean, std = compute_mean_std(_experiment_paths(experiment))
    stats_experiments[experiment] = {'mean': mean, 'std': std}

with open(FILENAME, 'wb') as f:
    pickle.dump(stats_experiments, f)

print()
print('Verification:')
# Reload the file just written and standardize each experiment's images with
# the stored statistics: the recomputed mean/std should be close to 0 and 1.
with open(FILENAME, 'rb') as f:
    stats_experiments = pickle.load(f)
for experiment in experiments:
    stats = stats_experiments[experiment]
    mean, std = compute_mean_std(_experiment_paths(experiment),
                                 mean=stats['mean'], std=stats['std'])
    print('experiment=', experiment)
    print('mean=', mean)
    print('std=', std)