forked from lzx325/COVID-19-repo
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
114 lines (104 loc) · 4.1 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import numpy as np
import scipy
import matplotlib.pyplot as plt
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms
import glob
import shutil
import pprint
import scipy.stats
import scipy.ndimage
import U_net_Model as unm
def predict_one(
    *,
    data_root="./datasets/example",
    filename="example.npy",
    lung_mask_root="./datasets/example/",
    lung_mask_fn=None,
    checkpoint_root="./checkpoint/",
    visualization_root="./prediction_visualization/",
    prediction_root="./prediction/",
    out_filename=None,
    visualization=True,
    threshold=2,
):
    """Predict, score and optionally visualize a single example volume.

    The input array must be 4-D. Channel 0 of its last axis is used as the
    image and channel 1 as the ground-truth mask. If a saved prediction
    already exists in ``prediction_root`` it is reused. Otherwise
    ``U_net_predict.final_prediction`` is run with the X/Y/Z checkpoints, and
    the result is multiplied by the lung mask and saved with ``np.save``.
    The prediction is then passed to ``f1_score_evaluation`` and, when
    ``visualization`` is true, to ``visualize_mask_and_raw_array``.

    All arguments are keyword-only. Their defaults are the values the script
    originally hard-coded, so ``predict_one()`` behaves as before.

    Args:
        data_root: Directory containing ``filename``.
        filename: Input ``.npy`` file, or ``.npz`` file storing the array
            under the key ``"example"``.
        lung_mask_root: Directory containing the lung-mask ``.npy`` file.
        lung_mask_fn: Lung-mask file name. Defaults to
            ``"<stem>_lung-mask.npy"``, where ``<stem>`` is ``filename``
            without its extension.
        checkpoint_root: Directory holding ``best_model-{X,Y,Z}.pth``.
        visualization_root: Parent directory of the visualization output.
        prediction_root: Directory in which the prediction is cached.
        out_filename: Cached prediction file name. Defaults to
            ``"prediction-<stem>.npy"``. It should end in ``.npy``:
            ``np.save`` appends that suffix otherwise, and the cache check
            would never find the file.
        visualization: Whether to produce the visualization.
        threshold: Forwarded unchanged to ``final_prediction``.

    Raises:
        ValueError: If ``filename`` is neither ``.npy`` nor ``.npz``, or if
            the loaded array is not 4-D.
    """
    from U_net_predict import f1_score_evaluation, final_prediction
    from visualize_mask_and_raw import visualize_mask_and_raw_array

    # Derive per-input names from the input stem so that a different
    # ``filename`` never picks up another file's cached prediction.
    stem = os.path.splitext(filename)[0]
    if lung_mask_fn is None:
        lung_mask_fn = "%s_lung-mask.npy" % (stem)
    if out_filename is None:
        out_filename = "prediction-%s.npy" % (stem)

    # One checkpoint per key X/Y/Z (presumably one model per slicing axis;
    # confirm in U_net_predict.final_prediction).
    best_model_fns = {
        axis: os.path.join(checkpoint_root, "best_model-%s.pth" % (axis))
        for axis in ("X", "Y", "Z")
    }

    data_path = os.path.join(data_root, filename)
    if filename.endswith(".npz"):
        # np.load returns an open NpzFile for .npz archives; close it.
        with np.load(data_path) as archive:
            data_array = archive["example"]
    elif filename.endswith(".npy"):
        data_array = np.load(data_path)
    else:
        # A real exception, not ``assert``, so it survives ``python -O``.
        raise ValueError(
            "Unsupported input file (expected .npy or .npz): %s" % (filename)
        )
    if data_array.ndim != 4:
        raise ValueError(
            "%s: expected a 4-D array (image and ground truth stacked on the "
            "last axis), got shape %r" % (filename, data_array.shape)
        )

    for directory in (checkpoint_root, visualization_root, prediction_root):
        os.makedirs(directory, exist_ok=True)

    image = data_array[:, :, :, 0]
    gt_mask = data_array[:, :, :, 1]
    lung_mask = np.load(os.path.join(lung_mask_root, lung_mask_fn))

    prediction_file = os.path.join(prediction_root, out_filename)
    if os.path.isfile(prediction_file):
        print("Load from saved prediction file: %s" % (prediction_file))
        pred = np.load(prediction_file)
    else:
        print("Predicting %s" % (filename))
        pred = final_prediction(
            image, best_model_fns, threshold=threshold, lung_mask=lung_mask
        )
        # Restrict the prediction to the lung mask before caching it
        # (assumes a 0/1 mask; confirm with the mask producer).
        pred = pred * lung_mask
        np.save(prediction_file, pred)

    print("Computing F1 score")
    f1_score_evaluation(pred, gt_mask, lung_mask, filter_ground_truth=True)
    if visualization:
        print("Preparing visualization")
        # The trailing "" keeps the final path separator that the original
        # string concatenation produced.
        visualization_dir = os.path.join(visualization_root, out_filename, "")
        visualize_mask_and_raw_array(visualization_dir, image, pred, gt_mask)
def predict_all(
    *,
    data_root="./arrays_raw",
    checkpoint_root="./checkpoint/",
    prediction_root="./prediction/",
    threshold=2,
):
    """Run ``final_prediction`` on every ``.npy``/``.npz`` file in ``data_root``.

    Files are processed in sorted name order. For an input named
    ``<stem>.npy`` or ``<stem>.npz``, the prediction is written with
    ``np.savez`` to ``<prediction_root>/<stem>-prediction.npz`` under the key
    ``"array"``. Inputs whose prediction file already exists are skipped, so
    an interrupted run can be resumed. Files with any other extension are
    skipped with a message.

    For a 4-D input, channel 0 of the last axis is used as the image. A 3-D
    input is used as-is. ``.npz`` inputs must store their array under the
    key ``"example"``.

    All arguments are keyword-only. Their defaults are the values the script
    originally hard-coded, so ``predict_all()`` behaves as before for valid
    inputs.

    Args:
        data_root: Directory scanned for input arrays.
        checkpoint_root: Directory holding ``best_model-{X,Y,Z}.pth``.
        prediction_root: Directory receiving ``<stem>-prediction.npz`` files.
        threshold: Forwarded unchanged to ``final_prediction``.

    Raises:
        ValueError: If a ``.npy``/``.npz`` input is neither 3-D nor 4-D.
    """
    from U_net_predict import final_prediction

    # The checkpoints and output directories are the same for every file,
    # so set them up once instead of on every iteration.
    best_model_fns = {
        axis: os.path.join(checkpoint_root, "best_model-%s.pth" % (axis))
        for axis in ("X", "Y", "Z")
    }
    os.makedirs(checkpoint_root, exist_ok=True)
    os.makedirs(prediction_root, exist_ok=True)

    # os.listdir order is arbitrary; sort it for reproducible runs and logs.
    for filename in sorted(os.listdir(data_root)):
        if not filename.endswith((".npy", ".npz")):
            print("Skipping unsupported file: %s" % (filename))
            continue

        noext = os.path.splitext(filename)[0]
        prediction_file = os.path.join(
            prediction_root, "%s-prediction.npz" % (noext)
        )
        # The cached result was never used here, so skip the input entirely
        # rather than loading it (np.load on .npz returns an open NpzFile).
        if os.path.isfile(prediction_file):
            print(
                "Skipping %s: saved prediction file exists: %s"
                % (filename, prediction_file)
            )
            continue

        data_path = os.path.join(data_root, filename)
        if filename.endswith(".npz"):
            with np.load(data_path) as archive:
                data_array = archive["example"]
        else:
            data_array = np.load(data_path)

        if data_array.ndim == 4:
            image = data_array[:, :, :, 0]
        elif data_array.ndim == 3:
            image = data_array
        else:
            # Without this branch ``image`` would silently keep the previous
            # file's volume, and that volume would be predicted again.
            raise ValueError(
                "%s: expected a 3-D or 4-D array, got shape %r"
                % (filename, data_array.shape)
            )

        print("Predicting %s" % (filename))
        pred = final_prediction(image, best_model_fns, threshold=threshold)
        np.savez(prediction_file, array=pred)
if __name__ == "__main__":
    # Script entry point: run the single-example pipeline with its default
    # paths. predict_all() is available for batch runs but is not invoked here.
    predict_one()