-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Seg example #42
base: master
Are you sure you want to change the base?
Seg example #42
Changes from 3 commits
70aaa31
c8674b0
98ebeee
62d3c1c
26c1ffe
59afc55
310726e
76f834d
13aa1da
4ad76f2
323c270
960a484
5a4a6cf
e1e601d
75ce22b
c53c96c
3551533
97eac6b
dddac6d
83936d0
69a3b09
5771c85
13c5f0b
c12ff08
3e3a476
5f614a0
9b511cc
532b347
1a94187
92d78d1
0b8b573
fb72386
bab4bf7
d4a5c02
68df39c
ae69698
b4804c3
63f2a29
c40f129
419a36b
13536f0
1077b8c
85cdae4
26f78f8
0f3abf3
5d0b65a
3f71935
17ef080
1dc5050
2e1f98c
dbe44a5
6c4232c
0ef6742
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -27,12 +27,16 @@ | |
from scipy.ndimage.filters import gaussian_filter | ||
from scipy.ndimage.interpolation import map_coordinates | ||
from torch import Tensor | ||
import elasticdeform as ed | ||
|
||
from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerGaussianPatch as Gaussian | ||
from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerRandBool as RandBool | ||
from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerRandInt as RandInt | ||
from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerUniform as Uniform | ||
|
||
import PIL | ||
import torch.nn.functional as F | ||
|
||
|
||
######## Affine augmentation | ||
def aug_op_affine(aug_input: Tensor, rotate: float = 0.0, translate: Tuple[float, float] = (0.0, 0.0), | ||
|
@@ -63,7 +67,12 @@ def aug_op_affine(aug_input: Tensor, rotate: float = 0.0, translate: Tuple[float | |
for channel in channels: | ||
aug_channel_tensor = aug_input[channel].numpy() | ||
aug_channel_tensor = Image.fromarray(aug_channel_tensor) | ||
aug_channel_tensor = TTF.affine(aug_channel_tensor, angle=rotate, scale=scale, translate=translate, shear=shear) | ||
aug_channel_tensor = TTF.affine(aug_channel_tensor, | ||
angle=rotate, | ||
scale=scale, | ||
resample=PIL.Image.BILINEAR, | ||
translate=translate, | ||
shear=shear) | ||
if flip[0]: | ||
aug_channel_tensor = TTF.vflip(aug_channel_tensor) | ||
if flip[1]: | ||
|
@@ -265,23 +274,13 @@ def aug_op_elastic_transform(aug_input: Tensor, alpha: float = 1, sigma: float = | |
:param channels: which channels to apply the augmentation | ||
:return distorted image | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why did you change it? the previous implementation didn't work? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I find it more straightforward and easy to use. |
||
random_state = numpy.random.RandomState(None) | ||
if channels is None: | ||
channels = list(range(aug_input.shape[0])) | ||
aug_tensor = aug_input.numpy() | ||
for channel in channels: | ||
aug_channel_tensor = aug_input[channel].numpy() | ||
shape = aug_channel_tensor.shape | ||
dx1 = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha | ||
dx2 = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha | ||
# convert back to torch tensor | ||
aug_input = [numpy.array(t) for t in aug_input] | ||
aug_input_d = ed.deform_random_grid(aug_input, sigma=7, points=3, axis=[(1, 2), (1,2)]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't use fixed values - use the function arguments |
||
|
||
x1, x2 = numpy.meshgrid(numpy.arange(shape[0]), numpy.arange(shape[1])) | ||
indices = numpy.reshape(x2 + dx2, (-1, 1)), numpy.reshape(x1 + dx1, (-1, 1)) | ||
aug_output = [torch.from_numpy(t) for t in aug_input_d] | ||
|
||
distored_image = map_coordinates(aug_channel_tensor, indices, order=1, mode='reflect') | ||
distored_image = distored_image.reshape(aug_channel_tensor.shape) | ||
aug_tensor[channel] = distored_image | ||
return torch.from_numpy(aug_tensor) | ||
return aug_output | ||
|
||
|
||
######### Default / Example augmentation pipline for a 2D image | ||
|
@@ -452,3 +451,39 @@ def aug_op_batch_mix_up(aug_input: Tuple[Tensor, Tensor], factor: float) -> Tupl | |
img = img * (1.0 - factor) + factor * img_mix_up | ||
labels = labels * (1.0 - factor) + factor * labels_mix_up | ||
return img, labels | ||
|
||
|
||
def aug_op_random_crop_and_resize(aug_input: Tensor, | ||
out_size, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing type annotation - you have more though - can you please go over the new code and add type annotations? |
||
crop_size: float = 1.0, # or optional - Tuple[float, float] | ||
x_off: float = 1.0, | ||
y_off: float = 1.0, | ||
z_off: float = 1.0) -> Tensor: | ||
""" | ||
random crop a (3d) tensor and resize it to a given size | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add here aug_input and which dimensions you expect |
||
:param crop_size: float <= 1.0 - the fraction to crop from the original tensor for each dim | ||
:param x_off: float <= 1.0 - the x-offset to take | ||
:param y_off: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add doc |
||
:param z_off: | ||
:param out_size: the size of the output tensor | ||
:return: the output tensor | ||
""" | ||
in_shape = aug_input.shape | ||
|
||
if len(aug_input.shape) == 4: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. else not supported error? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added support for 2d tensor. Need your help on how to handle other dimensions (error?) |
||
ch, z, y, x = in_shape | ||
|
||
x_width = int(crop_size * x) | ||
x_off = int(x_off * (x - x_width)) | ||
|
||
y_width = int(crop_size * y) | ||
y_off = int(y_off * (y - y_width)) | ||
|
||
z_width = int(crop_size * z) | ||
z_off = int(z_off * (z - z_width)) | ||
|
||
aug_tensor = aug_input[:, z_off:z_off+z_width, y_off:y_off+y_width, x_off:x_off+x_width] | ||
|
||
aug_tensor = F.interpolate(aug_tensor, out_size) | ||
|
||
return aug_tensor |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,7 +65,7 @@ def __call__(self, predict, target): | |
predict = predict.contiguous().view(predict.shape[0], -1) | ||
target = target.contiguous().view(target.shape[0], -1) | ||
|
||
if target.dtype == torch.int64: | ||
if target.dtype == torch.int64 or target.dtype == torch.int32: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we care about the type? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I also think that the to device is not necessary here - can you remove it> There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I remember I had to add this option. not sure what was the case. |
||
target = target.type(torch.float32).to(target.device) | ||
num = 2*torch.sum(torch.mul(predict, target), dim=1) + self.eps | ||
den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.eps | ||
|
@@ -82,6 +82,80 @@ def __call__(self, predict, target): | |
raise Exception('Unexpected reduction {}'.format(self.reduction)) | ||
|
||
|
||
class DiceBCELoss(FuseLossBase): | ||
|
||
def __init__(self, | ||
pred_name, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing type annotations |
||
target_name, | ||
filter_func: Optional[Callable]=None, | ||
class_weights=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type annotation |
||
bce_weight: float=1.0, | ||
power: int=1, | ||
eps: float=1., | ||
reduction: str='mean'): | ||
''' | ||
Compute a weighted sum of dice-loss and cross entropy loss. | ||
|
||
:param pred_name: batch_dict key for predicted output (e.g., class probabilities after softmax). | ||
Expected Tensor shape = [batch, num_classes, height, width] | ||
:param target_name: batch_dict key for target (e.g., ground truth label). Expected Tensor shape = [batch, height, width] | ||
:param filter_func: function that filters batch_dict/ The function gets ans input batch_dict and returns filtered batch_dict | ||
:param class_weights: An array of shape [num_classes,] | ||
:param bce_weight: weight to attach to the bce loss, default : 1.0 | ||
:param power: Denominator value: \sum{x^p} + \sum{y^p}, default: 1 | ||
:param eps: A float number to smooth loss, and avoid NaN error, default: 1 | ||
:param reduction: Reduction method to apply, return mean over batch if 'mean', | ||
return sum if 'sum', return a tensor of shape [N,] if 'none' | ||
|
||
Returns: Loss tensor according to arg reduction | ||
Raise: Exception if unexpected reduction | ||
''' | ||
|
||
super().__init__(pred_name, target_name, 1.0) | ||
self.class_weights = class_weights | ||
self.bce_weight = bce_weight | ||
self.filter_func = filter_func | ||
self.dice = BinaryDiceLoss(power, eps, reduction) | ||
|
||
def __call__(self, batch_dict): | ||
|
||
if self.filter_func is not None: | ||
batch_dict = self.filter_func(batch_dict) | ||
predict = FuseUtilsHierarchicalDict.get(batch_dict, self.pred_name).float() | ||
target = FuseUtilsHierarchicalDict.get(batch_dict, self.target_name).long() | ||
|
||
target = target.type(torch.float32).to(target.device) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is it really required? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Im not sure... |
||
|
||
total_loss = 0 | ||
n_classes = predict.shape[1] | ||
|
||
# Convert target to one hot encoding | ||
if n_classes > 1 and target.shape[1] != n_classes: | ||
target = make_one_hot(target, n_classes) | ||
|
||
assert predict.shape == target.shape, 'predict & target shape do not match' | ||
|
||
# import ipdb; ipdb.set_trace(context=7) # BREAKPOINT | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove |
||
total_class_weights = sum(self.class_weights) if self.class_weights is not None else n_classes | ||
for cls_index in range(n_classes): | ||
dice_loss = self.dice(predict[:, cls_index, :, :], target[:, cls_index, :, :]) | ||
if self.bce_weight > 0.0: | ||
bce_loss = F.binary_cross_entropy(predict[:, cls_index, :, :].view(-1), | ||
target[:, cls_index, :, :].view(-1), | ||
reduction='mean') | ||
dice_loss += self.bce_weight * bce_loss | ||
|
||
if self.class_weights is not None: | ||
assert self.class_weights.shape[0] == n_classes, \ | ||
'Expect weight shape [{}], got[{}]'.format(n_classes, self.class_weights.shape[0]) | ||
dice_loss *= self.class_weights[cls_index] | ||
|
||
total_loss += dice_loss | ||
total_loss /= total_class_weights | ||
|
||
return self.weight*total_loss | ||
|
||
|
||
class FuseDiceLoss(FuseLossBase): | ||
|
||
def __init__(self, pred_name, | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a new dependency - add it to requirements.txt
If there are more, please add them too.