
Commit 1ecadee

Merge pull request #31 from HiLab-git/dev
Dev
2 parents a114a79 + 1e79947 commit 1ecadee

15 files changed (+235, -257 lines)

README.md (+2, -2)

````diff
@@ -47,10 +47,10 @@ Run the following command to install the latest released version of PyMIC:
 ```bash
 pip install PYMIC
 ```
-To install a specific version of PYMIC such as 0.3.0, run:
+To install a specific version of PYMIC such as 0.3.1, run:
 
 ```bash
-pip install PYMIC==0.3.0
+pip install PYMIC==0.3.1
 ```
 Alternatively, you can download the source code for the latest version. Run the following command to compile and install:
 
````
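To confirm which version ended up installed after either command, a quick check (this reads the installed package metadata via `pip show`; the `grep` filter is just for brevity):

```bash
# Verify the installed PyMIC version, e.g. after the pinned install above
pip show PYMIC | grep -i "^version"
```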

pymic/loss/cls/basic.py (+2, -8)

```diff
@@ -65,10 +65,7 @@ def forward(self, loss_input_dict):
         labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1
         softmax = nn.Softmax(dim = 1)
         predict = softmax(predict)
-        num_class = list(predict.size())[1]
-        data_type = 'float' if(predict.dtype is torch.float32) else 'double'
-        soft_y = get_soft_label(labels, num_class, data_type)
-        loss = self.l1_loss(predict, soft_y)
+        loss = self.l1_loss(predict, labels)
         return loss
 
 class MSELoss(AbstractClassificationLoss):
@@ -84,10 +81,7 @@ def forward(self, loss_input_dict):
         labels = loss_input_dict['ground_truth'][:, None] # reshape to N, 1
         softmax = nn.Softmax(dim = 1)
         predict = softmax(predict)
-        num_class = list(predict.size())[1]
-        data_type = 'float' if(predict.dtype is torch.float32) else 'double'
-        soft_y = get_soft_label(labels, num_class, data_type)
-        loss = self.mse_loss(predict, soft_y)
+        loss = self.mse_loss(predict, labels)
         return loss
 
 class NLLLoss(AbstractClassificationLoss):
```
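For context on what these hunks change: previously the integer labels were expanded to a one-hot tensor matching the softmax output before the L1/MSE comparison. A minimal standalone sketch of the two shapes involved (the tensors here are hypothetical, and the one-hot expansion stands in for PyMIC's `get_soft_label`):

```python
import torch
import torch.nn as nn

# Hypothetical batch: N = 2 samples, C = 3 classes
predict = torch.softmax(torch.randn(2, 3), dim = 1)        # (N, C) probabilities
labels  = torch.tensor([[1], [2]], dtype = torch.float32)  # (N, 1) class indices

# Old behavior: expand labels to a one-hot (N, C) tensor, then compare.
soft_y   = torch.zeros_like(predict).scatter_(1, labels.long(), 1.0)
loss_old = nn.L1Loss()(predict, soft_y)

# New behavior: compare the (N, 1) label tensor with predict directly.
# Note that PyTorch broadcasts the target here and emits a UserWarning
# when target and input shapes differ.
loss_new = nn.L1Loss()(predict, labels)
```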

pymic/loss/seg/ce.py (-1)

```diff
@@ -18,7 +18,6 @@ class CrossEntropyLoss(AbstractSegLoss):
     """
     def __init__(self, params = None):
         super(CrossEntropyLoss, self).__init__(params)
-
 
     def forward(self, loss_input_dict):
         predict = loss_input_dict['prediction']
```

pymic/loss/seg/ssl.py (+3, -2)

```diff
@@ -6,8 +6,9 @@
 import torch.nn as nn
 import numpy as np
 from pymic.loss.seg.util import reshape_tensor_to_2D
+from pymic.loss.seg.abstract import AbstractSegLoss
 
-class EntropyLoss(nn.Module):
+class EntropyLoss(AbstractSegLoss):
     """
     Entropy Minimization for segmentation tasks.
     The parameters should be written in the `params` dictionary, and it has the
@@ -43,7 +44,7 @@ def forward(self, loss_input_dict):
         avg_ent = torch.mean(entropy)
         return avg_ent
 
-class TotalVariationLoss(nn.Module):
+class TotalVariationLoss(AbstractSegLoss):
     """
     Total Variation Loss for segmentation tasks.
     The parameters should be written in the `params` dictionary, and it has the
```
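Rebasing these losses on `AbstractSegLoss` presumably lets them share the same interface as PyMIC's other segmentation losses. For reference, a minimal sketch of the entropy-minimization term that `EntropyLoss` represents (an illustration under assumed shapes, not PyMIC's exact implementation):

```python
import torch

def entropy_loss_sketch(prob, eps = 1e-5):
    # prob: softmax probabilities of shape (N, C, H, W)
    # Per-pixel entropy -sum_c p_c * log(p_c), averaged over all pixels
    entropy = -torch.sum(prob * torch.log(prob + eps), dim = 1)
    return torch.mean(entropy)

# Hypothetical 2-class, 4x4 prediction map
pred = torch.softmax(torch.randn(1, 2, 4, 4), dim = 1)
print(entropy_loss_sketch(pred))
```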

pymic/net/cls/torch_pretrained_net.py (+4, -4)

```diff
@@ -75,7 +75,7 @@ def __init__(self, params):
     def get_parameters_to_update(self):
         if(self.update_mode == "all"):
             return self.net.parameters()
-        elif(self.update_layers == "last"):
+        elif(self.update_mode == "last"):
            params = self.net.fc.parameters()
            if(self.in_chns !=3):
                # combining the two iterables into a single one
@@ -119,7 +119,7 @@ def get_parameters_to_update(self):
             params = self.net.classifier[-1].parameters()
             if(self.in_chns !=3):
                 params = itertools.chain()
-                for pram in [self.net.classifier[-1].parameters(), self.net.net.features[0].parameters()]:
+                for pram in [self.net.classifier[-1].parameters(), self.net.features[0].parameters()]:
                     params = itertools.chain(params, pram)
                 return params
         else:
@@ -138,7 +138,7 @@ class MobileNetV2(BuiltInNet):
     as well as the first layer when `input_chns` is not 3.
     """
     def __init__(self, params):
-        super(MobileNetV2, self).__init__()
+        super(MobileNetV2, self).__init__(params)
         self.net = models.mobilenet_v2(pretrained = self.pretrain)
 
         # replace the last layer
@@ -157,7 +157,7 @@ def get_parameters_to_update(self):
             params = self.net.classifier[-1].parameters()
             if(self.in_chns !=3):
                 params = itertools.chain()
-                for pram in [self.net.classifier[-1].parameters(), self.net.net.features[0][0].parameters()]:
+                for pram in [self.net.classifier[-1].parameters(), self.net.features[0][0].parameters()]:
                     params = itertools.chain(params, pram)
                 return params
         else:
```
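Three of these hunks are attribute fixes (`update_layers` → `update_mode`, `self.net.net.features` → `self.net.features`, and the missing `params` argument to `super().__init__`). The `itertools.chain` pattern they repair merges several parameter iterators so that only selected layers are fine-tuned; a standalone sketch using a plain torchvision model (hypothetical usage, mirroring the corrected attribute paths):

```python
import itertools
import torch
from torchvision import models

net = models.mobilenet_v2()  # randomly initialized, just for the sketch

# Chain the last classifier layer with the first conv layer, as the
# corrected code does when the input channel count is not 3.
params = itertools.chain()
for pram in [net.classifier[-1].parameters(), net.features[0][0].parameters()]:
    params = itertools.chain(params, pram)

# The chained iterator can be handed to an optimizer directly.
optimizer = torch.optim.SGD(params, lr = 1e-3)
```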

pymic/net/net2d/unet2d_urpc.py (-132)

This file was deleted.

pymic/net/net_dict_seg.py (-3)

```diff
@@ -4,7 +4,6 @@
 
 * UNet2D :mod:`pymic.net.net2d.unet2d.UNet2D`
 * UNet2D_DualBranch :mod:`pymic.net.net2d.unet2d_dual_branch.UNet2D_DualBranch`
-* UNet2D_URPC :mod:`pymic.net.net2d.unet2d_urpc.UNet2D_URPC`
 * UNet2D_CCT :mod:`pymic.net.net2d.unet2d_cct.UNet2D_CCT`
 * UNet2D_ScSE :mod:`pymic.net.net2d.unet2d_scse.UNet2D_ScSE`
 * AttentionUNet2D :mod:`pymic.net.net2d.unet2d_attention.AttentionUNet2D`
@@ -17,7 +16,6 @@
 from __future__ import print_function, division
 from pymic.net.net2d.unet2d import UNet2D
 from pymic.net.net2d.unet2d_dual_branch import UNet2D_DualBranch
-from pymic.net.net2d.unet2d_urpc import UNet2D_URPC
 from pymic.net.net2d.unet2d_cct import UNet2D_CCT
 from pymic.net.net2d.cople_net import COPLENet
 from pymic.net.net2d.unet2d_attention import AttentionUNet2D
@@ -30,7 +28,6 @@
 SegNetDict = {
     'UNet2D': UNet2D,
     'UNet2D_DualBranch': UNet2D_DualBranch,
-    'UNet2D_URPC': UNet2D_URPC,
     'UNet2D_CCT': UNet2D_CCT,
     'COPLENet': COPLENet,
     'AttentionUNet2D': AttentionUNet2D,
```
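With `UNet2D_URPC` deleted, configurations that still name it will no longer resolve. A minimal sketch of how a network is typically looked up from `SegNetDict` by its config name (the `params` keys shown are hypothetical and may differ from what `UNet2D` actually expects):

```python
from pymic.net.net_dict_seg import SegNetDict

# Hypothetical configuration for a 2D U-Net
params = {
    'net_type':     'UNet2D',
    'in_chns':      1,
    'class_num':    2,
    'feature_chns': [16, 32, 64, 128, 256],
    'dropout':      [0.0, 0.0, 0.3, 0.4, 0.5],
    'bilinear':     True
}

net_name = params['net_type']
if net_name not in SegNetDict:  # e.g. the removed 'UNet2D_URPC'
    raise ValueError("undefined network: {0:}".format(net_name))
net = SegNetDict[net_name](params)
```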
