forked from 8DM20-group6/ProstateSegmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
143 lines (119 loc) · 4.85 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import json
import time
import torch
import torch.nn as nn
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from collections import OrderedDict
from models.vae import get_noise
class Logger():
    """Save periodic visualizations of model outputs during training.

    Images are written as PNG files under a timestamped run directory whose
    parent folder depends on the model type ("segmentation_results" for a
    UNet, "vae_results" for a VAE).
    """

    def __init__(self, config, model, train_loader=None):
        """Set up the run directory and the fixed visualization batch.

        Parameters
        ----------
        config : dict
            Full configuration; only its ``"train"`` section is kept. It must
            provide ``"decay_lr_after"``, ``"epochs"`` and the model-specific
            learning rate (``"lr_unet"`` or ``"lr_vae"``); a VAE additionally
            needs ``"z_dim"`` and ``"device"``.
        model : torch.nn.Module
            Model being trained; its class name selects the logging mode.
        train_loader : iterable, optional
            Loader whose first batch is kept and re-plotted every epoch (UNet
            mode). When omitted, no batch is stored and ``self.vis`` is None.
        """
        self.config = config["train"]
        self.decay_lr_after = self.config["decay_lr_after"]
        self.epochs = self.config["epochs"]
        self.results_dir = self.determine_dir(model)
        print(f"Logging to {self.results_dir}")
        # One batch fetched up front so every epoch visualizes the same samples.
        # Honour the documented default: iter(None) would raise a TypeError.
        self.vis = next(iter(train_loader)) if train_loader is not None else None

    def determine_dir(self, model):
        """Choose the results directory and model-specific settings.

        Sets ``self.modelname`` and ``self.lr`` (and ``self.noise`` for a VAE)
        as a side effect.

        Returns
        -------
        pathlib.Path
            ``<cwd>/<results folder>/<timestamp>_epochs<E>_lr<LR>_decay<D>``.
            The directory is not created here.

        Raises
        ------
        Exception
            If the model class is neither ``UNet`` nor ``VAE``.
        """
        model_type = model.__class__.__name__
        if model_type == "UNet":
            RESULTS_DIR = "segmentation_results"
            self.modelname = "UNet"
            self.lr = self.config["lr_unet"]
        elif model_type == "VAE":
            RESULTS_DIR = "vae_results"
            self.modelname = "VAE"
            self.lr = self.config["lr_vae"]
            # Drawn once and reused by visualize_train, so successive epochs
            # decode the same latent inputs (presumably 32 samples of z_dim).
            self.noise = get_noise(32, self.config["z_dim"], device=self.config["device"])
        else:
            raise Exception("What model is this bro?")
        timestr = time.strftime("%Y%m%d_%H%M%S")
        return Path.cwd() / RESULTS_DIR / f"{timestr}_epochs{self.epochs}_lr{self.lr}_decay{self.decay_lr_after}"

    def visualize_train(self, model, epoch):
        """Save visualizations for ``epoch`` under ``<results_dir>/train_imgs``.

        UNet: ``<epoch>.png`` showing the first sample of the stored batch
        (``vis[0]``), its label (``vis[1]``) and the predicted sigmoid heatmap.
        VAE: ``<epoch>_img.png`` and ``<epoch>_mask.png`` showing the first
        sample decoded from the fixed noise by ``model.generator`` and
        ``model.generator_mask`` (the mask rounded to the nearest integer).
        """
        # Select the non-interactive backend BEFORE any figure exists: with
        # pyplot already imported, switching backends closes every open
        # figure, which could discard a figure drawn just before saving it.
        matplotlib.use('Agg')
        vis_dir = self.results_dir / "train_imgs"
        vis_dir.mkdir(parents=True, exist_ok=True)
        # Plotting only: no autograd graph is needed for these forward passes.
        with torch.no_grad():
            if self.modelname == "UNet":
                self._save_unet_figure(model, vis_dir / f"{epoch}.png")
            elif self.modelname == "VAE":
                self._save_vae_figures(model, vis_dir, epoch)

    def _save_unet_figure(self, model, out_path):
        """Plot input, label and predicted heatmap of the first stored sample."""
        if self.vis is None:
            raise ValueError("UNet visualization needs a batch: pass train_loader to Logger")
        images, labels = self.vis[0], self.vis[1]
        # The model returns a 4-tuple; only the first element (logits) is plotted.
        predict_logits, _, _, _ = model(images.to(self.config["device"]))
        heatmap = torch.sigmoid(predict_logits)
        fig, axs = plt.subplots(1, 3)
        try:
            axs[0].imshow(images[0, :, :, :].squeeze().detach().cpu(), cmap="gray")
            axs[1].imshow(labels[0, :, :, :].squeeze().detach().cpu(), cmap="gray")
            axs[2].imshow(heatmap[0, :, :, :].squeeze().detach().cpu(), cmap="hot")
            fig.savefig(out_path)
        finally:
            plt.close(fig)

    def _save_vae_figures(self, model, vis_dir, epoch):
        """Decode the fixed noise with both VAE generators and save sample 0."""
        img_generated = model.generator(self.noise)  # presumably (32, 1, 64, 64)
        self._save_image(img_generated[0, 0, :, :].detach().cpu(), vis_dir / f"{epoch}_img.png")
        mask_generated = model.generator_mask(self.noise)
        mask = mask_generated[0, 0, :, :].detach().cpu().round()
        self._save_image(mask, vis_dir / f"{epoch}_mask.png")

    @staticmethod
    def _save_image(image, out_path):
        """Render one 2-D image in grayscale and write it to ``out_path``."""
        fig, ax = plt.subplots()
        try:
            ax.imshow(image, cmap="gray")
            fig.savefig(out_path)
        finally:
            plt.close(fig)
class DiceBCELoss(nn.Module):
    """Segmentation loss: soft Dice loss plus binary cross-entropy.

    Notes
    -----
    The inputs are expected to be logits (raw, pre-sigmoid scores). Targets
    hold labels in [0, 1] and may be integer or floating point; they are cast
    to the dtype of the inputs before the loss is computed.
    """

    def __init__(self):
        super(DiceBCELoss, self).__init__()

    def forward(self, outputs, targets, smooth=1):
        """Calculate the segmentation loss for training.

        Parameters
        ----------
        outputs : torch.Tensor
            Logits predicted by the segmentation model (any shape).
        targets : torch.Tensor
            Ground-truth labels with the same number of elements as
            ``outputs``; integer labels are accepted.
        smooth : float
            Smoothing term for the Dice score that avoids division by zero,
            by default 1.

        Returns
        -------
        torch.Tensor
            Scalar tensor: mean binary cross-entropy plus Dice loss.
        """
        # reshape (not view) so non-contiguous tensors, e.g. permuted or
        # sliced batches, are accepted.
        logits = outputs.reshape(-1)
        # BCE requires floating-point targets matching the input dtype; the
        # integer labels described above would otherwise raise a dtype error.
        targets = targets.reshape(-1).to(logits.dtype)
        probs = torch.sigmoid(logits)

        # Soft Dice computed on probabilities.
        intersection = (probs * targets).sum()
        dice_loss = 1 - (2.0 * intersection + smooth) / (
            probs.sum() + targets.sum() + smooth
        )
        # BCE from logits is numerically stable and autocast-safe, unlike
        # binary_cross_entropy applied to sigmoid outputs.
        BCE = nn.functional.binary_cross_entropy_with_logits(logits, targets, reduction="mean")
        return BCE + dice_loss
def load_config(filename):
    """Load a configuration dictionary from a .json file.

    Arguments:
        filename (str or os.PathLike): path to the .json file

    Returns:
        OrderedDict: the parsed config; every nested JSON object is also
        decoded as an OrderedDict.
    """
    filename = Path(filename)
    # JSON text is UTF-8 (RFC 8259): don't depend on the platform's locale
    # encoding. 'utf-8-sig' also strips a leading BOM, which json.load would
    # otherwise reject.
    with filename.open('rt', encoding='utf-8-sig') as handle:
        return json.load(handle, object_hook=OrderedDict)
def write_config(content, filename):
    """Serialize ``content`` to ``filename`` as pretty-printed JSON.

    Arguments:
        content (dict): JSON-serializable configuration to save
        filename (str or os.PathLike): destination path; an existing file
            is overwritten

    Keys keep their insertion order (no sorting) and nested values are
    indented by four spaces.
    """
    target = Path(filename)
    with target.open(mode='w') as out_file:
        json.dump(content, out_file, sort_keys=False, indent=4)