Seg example #42

Open · wants to merge 53 commits into base: master
The diff below shows changes from 3 of the 53 commits.

Commits (53)
70aaa31  Segmentation example - 1st commit (Mar 15, 2022)
c8674b0  fix elastic augmentation and loss bce+dice; update for new eval packa… (Mar 15, 2022)
98ebeee  Change to eval package (Mar 20, 2022)
62d3c1c  Merge branch 'master' into seg_example (mosheraboh, Mar 25, 2022)
26c1ffe  move example to a new folder (Mar 26, 2022)
59afc55  Update the create_data script + add some comment regarding the origin… (Mar 26, 2022)
310726e  remove old script and update main script according to PR comments (no… (Mar 26, 2022)
76f834d  change names to eval* (Mar 26, 2022)
13aa1da  Changes following the comments on the PR (Mar 27, 2022)
4ad76f2  remove commented code (Apr 4, 2022)
323c270  Merge branch 'master' into seg_example (Apr 4, 2022)
960a484  Merge branch 'master' into seg_example (mosheraboh, Apr 7, 2022)
5a4a6cf  change input desc to file names and processor to compute mask images … (Apr 10, 2022)
e1e601d  factor out end to end examples to seprate package (Apr 13, 2022)
75ce22b  add examples to PYTHONPATH (Apr 13, 2022)
c53c96c  run unittests in examples (Apr 13, 2022)
3551533  remove fuse1 data package (Apr 14, 2022)
97eac6b  remove dataset from manager (Apr 14, 2022)
dddac6d  convert mnist to fuse2 style (Apr 14, 2022)
83936d0  create dl package (Apr 17, 2022)
69a3b09  remove Fuse prefix (Apr 17, 2022)
5771c85  reorg examples (Apr 17, 2022)
13c5f0b  Merge branch 'master' of github.com:IBM/fuse-med-ml into fuse2 (Apr 17, 2022)
c12ff08  Merge branch 'fuse2' of github.com:IBM/fuse-med-ml into data_package (Apr 17, 2022)
3e3a476  add fuse data package (Apr 17, 2022)
5f614a0  Merge branch 'data_package' of github.com:IBM/fuse-med-ml into mnist_… (Apr 17, 2022)
9b511cc  adjust mnist runner (Apr 17, 2022)
532b347  imaging extension (Apr 18, 2022)
1a94187  Merge branch 'data_package' of github.com:IBM/fuse-med-ml into mnist_… (Apr 18, 2022)
92d78d1  Move changes from master's branch to mnist_fuse2_style's branch (Apr 18, 2022)
0b8b573  Fixed import path (Apr 18, 2022)
fb72386  remove the create-data script and move all its functionality to input… (Apr 19, 2022)
bab4bf7  Updated the notebook (mnist example) to fuse2 (Apr 25, 2022)
d4a5c02  Skip test - temp (Apr 25, 2022)
68df39c  Update test_notebook_hello_world.py (SagiPolaczek, Apr 25, 2022)
ae69698  Data package (#61) (mosheraboh, Apr 28, 2022)
b4804c3  skip test (it works locally) (Apr 28, 2022)
63f2a29  Move changes from master's branch to mnist_fuse2_style's branch (Apr 18, 2022)
c40f129  Fixed import path (Apr 18, 2022)
419a36b  Updated the notebook (mnist example) to fuse2 (Apr 25, 2022)
13536f0  Skip test - temp (Apr 25, 2022)
1077b8c  Update test_notebook_hello_world.py (SagiPolaczek, Apr 25, 2022)
85cdae4  skip test (it works locally) (Apr 28, 2022)
26f78f8  Merge branch 'hello_world_unittest' of github.com:IBM/fuse-med-ml int… (Apr 28, 2022)
0f3abf3  Fixed override in the set_device functionality and made cpu usage mor… (Apr 28, 2022)
5d0b65a  Merge pull request #63 from IBM/hello_world_unittest (SagiPolaczek, Apr 28, 2022)
3f71935  data package readme (mosheraboh, May 2, 2022)
17ef080  Merge with master (May 3, 2022)
1dc5050  merged with fuse2 (May 3, 2022)
2e1f98c  Fix import for fuse2 + add a static pipeline + change data source int… (May 10, 2022)
dbe44a5  Complete data pipeline including the dynamic part (May 10, 2022)
6c4232c  Working fuse2 version + fix to gaussian op data type (May 17, 2022)
0ef6742  remove comments and non-required input processor file (May 17, 2022)
fuse/data/augmentor/augmentor_toolbox.py (67 changes: 51 additions & 16 deletions)

@@ -27,12 +27,16 @@
 from scipy.ndimage.filters import gaussian_filter
 from scipy.ndimage.interpolation import map_coordinates
 from torch import Tensor
+import elasticdeform as ed

Collaborator: It's a new dependency - add it to requirements.txt. If there are more, please add them too.

 from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerGaussianPatch as Gaussian
 from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerRandBool as RandBool
 from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerRandInt as RandInt
 from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerUniform as Uniform
+
+import PIL
+import torch.nn.functional as F


 ######## Affine augmentation
 def aug_op_affine(aug_input: Tensor, rotate: float = 0.0, translate: Tuple[float, float] = (0.0, 0.0),
@@ -63,7 +67,12 @@ def aug_op_affine(aug_input: Tensor, rotate: float = 0.0, translate: Tuple[float,
     for channel in channels:
         aug_channel_tensor = aug_input[channel].numpy()
         aug_channel_tensor = Image.fromarray(aug_channel_tensor)
-        aug_channel_tensor = TTF.affine(aug_channel_tensor, angle=rotate, scale=scale, translate=translate, shear=shear)
+        aug_channel_tensor = TTF.affine(aug_channel_tensor,
+                                        angle=rotate,
+                                        scale=scale,
+                                        resample=PIL.Image.BILINEAR,
+                                        translate=translate,
+                                        shear=shear)
         if flip[0]:
             aug_channel_tensor = TTF.vflip(aug_channel_tensor)
         if flip[1]:
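The only functional change in this hunk is pinning the PIL resampling filter. One aside worth noting: in newer torchvision releases (roughly 0.10 and later) the `resample` keyword of `TTF.affine` was replaced by `interpolation`, so the equivalent call there looks like the sketch below (the placeholder image and the argument values are mine):

    from PIL import Image
    import torchvision.transforms.functional as TTF
    from torchvision.transforms import InterpolationMode

    img = Image.new("L", (64, 64))  # placeholder 8-bit grayscale image

    # same effect as resample=PIL.Image.BILINEAR on the torchvision used in this PR
    out = TTF.affine(img, angle=30.0, translate=(0, 0), scale=1.0, shear=0.0,
                     interpolation=InterpolationMode.BILINEAR)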
@@ -265,23 +274,13 @@ def aug_op_elastic_transform(aug_input: Tensor, alpha: float = 1, sigma: float =
     :param channels: which channels to apply the augmentation
     :return distorted image
     """

Collaborator: Why did you change it? The previous implementation didn't work? I see that you are not using channels - can you support it as well?

Collaborator (Author): I find it more straightforward and easy to use.

-    random_state = numpy.random.RandomState(None)
-    if channels is None:
-        channels = list(range(aug_input.shape[0]))
-    aug_tensor = aug_input.numpy()
-    for channel in channels:
-        aug_channel_tensor = aug_input[channel].numpy()
-        shape = aug_channel_tensor.shape
-        dx1 = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
-        dx2 = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha
-        x1, x2 = numpy.meshgrid(numpy.arange(shape[0]), numpy.arange(shape[1]))
-        indices = numpy.reshape(x2 + dx2, (-1, 1)), numpy.reshape(x1 + dx1, (-1, 1))
-        distored_image = map_coordinates(aug_channel_tensor, indices, order=1, mode='reflect')
-        distored_image = distored_image.reshape(aug_channel_tensor.shape)
-        aug_tensor[channel] = distored_image
-    return torch.from_numpy(aug_tensor)
+    aug_input = [numpy.array(t) for t in aug_input]
+    aug_input_d = ed.deform_random_grid(aug_input, sigma=7, points=3, axis=[(1, 2), (1, 2)])
+    # convert back to torch tensor
+    aug_output = [torch.from_numpy(t) for t in aug_input_d]
+    return aug_output

Collaborator (on the deform_random_grid call): Don't use fixed values - use the function arguments.
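A minimal sketch of what the two comments above ask for: pass the function's own arguments through to elasticdeform instead of the fixed `sigma=7, points=3`, and honor `channels`. The `points` parameter and both defaults are my assumptions (the current signature only has `alpha`, `sigma`, `channels`), so treat this as a sketch, not the merged code:

    from typing import List, Optional

    import elasticdeform as ed
    import numpy
    import torch
    from torch import Tensor


    def aug_op_elastic_transform(aug_input: Tensor,
                                 sigma: float = 7.0,
                                 points: int = 3,
                                 channels: Optional[List[int]] = None) -> Tensor:
        """Elastic deformation driven by the caller's arguments (hypothetical sketch)."""
        if channels is None:
            channels = list(range(aug_input.shape[0]))
        aug_tensor = aug_input.numpy().copy()
        # deform only the requested channels; a single call applies the same
        # random displacement grid to all of them
        deformed = ed.deform_random_grid([aug_tensor[c] for c in channels],
                                         sigma=sigma, points=points)
        for i, c in enumerate(channels):
            aug_tensor[c] = deformed[i]
        return torch.from_numpy(aug_tensor)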


 ######### Default / Example augmentation pipeline for a 2D image
@@ -452,3 +451,39 @@ def aug_op_batch_mix_up(aug_input: Tuple[Tensor, Tensor], factor: float) -> Tupl
     img = img * (1.0 - factor) + factor * img_mix_up
     labels = labels * (1.0 - factor) + factor * labels_mix_up
     return img, labels
+
+
+def aug_op_random_crop_and_resize(aug_input: Tensor,
+                                  out_size,
+                                  crop_size: float = 1.0,  # or optional - Tuple[float, float]
+                                  x_off: float = 1.0,
+                                  y_off: float = 1.0,
+                                  z_off: float = 1.0) -> Tensor:

Collaborator (on out_size): missing type annotation - you have more, though; can you please go over the new code and add type annotations?

+    """
+    random crop a (3d) tensor and resize it to a given size

Collaborator: Add here aug_input and which dimensions you expect.

+    :param crop_size: float <= 1.0 - the fraction to crop from the original tensor for each dim
+    :param x_off: float <= 1.0 - the x-offset to take
+    :param y_off:

Collaborator (on y_off): add doc

+    :param z_off:
+    :param out_size: the size of the output tensor
+    :return: the output tensor
+    """
+    in_shape = aug_input.shape
+
+    if len(aug_input.shape) == 4:

Collaborator: else not supported error?

Collaborator (Author): I added support for 2d tensor. Need your help on how to handle other dimensions (error?)

+        ch, z, y, x = in_shape
+
+        x_width = int(crop_size * x)
+        x_off = int(x_off * (x - x_width))
+
+        y_width = int(crop_size * y)
+        y_off = int(y_off * (y - y_width))
+
+        z_width = int(crop_size * z)
+        z_off = int(z_off * (z - z_width))
+
+        aug_tensor = aug_input[:, z_off:z_off+z_width, y_off:y_off+y_width, x_off:x_off+x_width]
+
+        aug_tensor = F.interpolate(aug_tensor, out_size)
+
+    return aug_tensor
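Pulling the review requests together, below is a sketch of how the function might end up: type annotations everywhere and an explicit error for unsupported ranks. The unsqueeze-plus-trilinear resize is my assumption about intent; the PR snapshot passes the 4D tensor to F.interpolate directly, which resizes only the last two axes:

    import torch
    import torch.nn.functional as F
    from torch import Tensor
    from typing import Tuple, Union


    def aug_op_random_crop_and_resize(aug_input: Tensor,
                                      out_size: Union[int, Tuple[int, int, int]],
                                      crop_size: float = 1.0,
                                      x_off: float = 1.0,
                                      y_off: float = 1.0,
                                      z_off: float = 1.0) -> Tensor:
        """
        Randomly crop a 4D [channels, z, y, x] tensor and resize the crop to out_size.

        :param aug_input: tensor of shape [channels, depth, height, width]
        :param out_size: target spatial size, an int or a (z, y, x) tuple
        :param crop_size: fraction (<= 1.0) of each spatial dim to keep
        :param x_off: fraction (<= 1.0) of the available x slack to skip before the crop
        :param y_off: same, for the y axis
        :param z_off: same, for the z axis
        :return: cropped and resized tensor of shape [channels, *out_size]
        """
        if aug_input.dim() != 4:
            raise ValueError(f"expected a 4D [ch, z, y, x] tensor, got {aug_input.dim()}D")
        ch, z, y, x = aug_input.shape
        x_width, y_width, z_width = int(crop_size * x), int(crop_size * y), int(crop_size * z)
        x0 = int(x_off * (x - x_width))
        y0 = int(y_off * (y - y_width))
        z0 = int(z_off * (z - z_width))
        crop = aug_input[:, z0:z0 + z_width, y0:y0 + y_width, x0:x0 + x_width]
        # F.interpolate wants a batch dim ([N, C, D, H, W]) to resize all three spatial axes
        resized = F.interpolate(crop.unsqueeze(0).float(), size=out_size, mode="trilinear",
                                align_corners=False)
        return resized.squeeze(0)

A call such as aug_op_random_crop_and_resize(torch.rand(1, 32, 64, 64), (16, 32, 32), crop_size=0.8, x_off=0.5, y_off=0.5, z_off=0.5) would then return a [1, 16, 32, 32] tensor.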
fuse/losses/segmentation/loss_dice.py (76 changes: 75 additions & 1 deletion)

@@ -65,7 +65,7 @@ def __call__(self, predict, target):
         predict = predict.contiguous().view(predict.shape[0], -1)
         target = target.contiguous().view(target.shape[0], -1)

-        if target.dtype == torch.int64:
+        if target.dtype == torch.int64 or target.dtype == torch.int32:

Collaborator: Why do we care about the type? Can we convert it anyway?

Collaborator: I also think that the to(device) is not necessary here - can you remove it?

Collaborator (Author): I remember I had to add this option; I'm not sure what the case was. I removed both to(device) calls.

             target = target.type(torch.float32).to(target.device)
         num = 2*torch.sum(torch.mul(predict, target), dim=1) + self.eps
         den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.eps
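My reading of the first comment above, as a one-line sketch (not the merged code): drop the dtype checks and cast unconditionally,

    # cast whatever integer dtype arrives (int32, int64, bool, ...) to float once
    target = target.float()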
@@ -82,6 +82,80 @@ def __call__(self, predict, target):
             raise Exception('Unexpected reduction {}'.format(self.reduction))

+class DiceBCELoss(FuseLossBase):
+
+    def __init__(self,
+                 pred_name,
+                 target_name,
+                 filter_func: Optional[Callable] = None,
+                 class_weights=None,
+                 bce_weight: float = 1.0,
+                 power: int = 1,
+                 eps: float = 1.,
+                 reduction: str = 'mean'):

Collaborator (on pred_name): missing type annotations

Collaborator (on class_weights): type annotation

+        '''
+        Compute a weighted sum of dice loss and binary cross entropy loss.
+
+        :param pred_name: batch_dict key for predicted output (e.g., class probabilities after softmax).
+                          Expected Tensor shape = [batch, num_classes, height, width]
+        :param target_name: batch_dict key for target (e.g., ground truth label). Expected Tensor shape = [batch, height, width]
+        :param filter_func: function that filters batch_dict. The function gets a batch_dict as input and returns the filtered batch_dict
+        :param class_weights: an array of shape [num_classes,]
+        :param bce_weight: weight to attach to the bce loss, default: 1.0
+        :param power: denominator value: \sum{x^p} + \sum{y^p}, default: 1
+        :param eps: a float number to smooth loss and avoid NaN error, default: 1
+        :param reduction: reduction method to apply: return mean over batch if 'mean',
+                          return sum if 'sum', return a tensor of shape [N,] if 'none'
+
+        Returns: loss tensor according to arg reduction
+        Raises: Exception if unexpected reduction
+        '''
+
+        super().__init__(pred_name, target_name, 1.0)
+        self.class_weights = class_weights
+        self.bce_weight = bce_weight
+        self.filter_func = filter_func
+        self.dice = BinaryDiceLoss(power, eps, reduction)
+
+    def __call__(self, batch_dict):
+
+        if self.filter_func is not None:
+            batch_dict = self.filter_func(batch_dict)
+        predict = FuseUtilsHierarchicalDict.get(batch_dict, self.pred_name).float()
+        target = FuseUtilsHierarchicalDict.get(batch_dict, self.target_name).long()
+
+        target = target.type(torch.float32).to(target.device)

Collaborator: is it really required?

Collaborator (Author): I'm not sure...

+        total_loss = 0
+        n_classes = predict.shape[1]
+
+        # Convert target to one hot encoding
+        if n_classes > 1 and target.shape[1] != n_classes:
+            target = make_one_hot(target, n_classes)
+
+        assert predict.shape == target.shape, 'predict & target shape do not match'
+
+        # import ipdb; ipdb.set_trace(context=7) # BREAKPOINT

Collaborator: remove

+        total_class_weights = sum(self.class_weights) if self.class_weights is not None else n_classes
+        for cls_index in range(n_classes):
+            dice_loss = self.dice(predict[:, cls_index, :, :], target[:, cls_index, :, :])
+            if self.bce_weight > 0.0:
+                bce_loss = F.binary_cross_entropy(predict[:, cls_index, :, :].view(-1),
+                                                  target[:, cls_index, :, :].view(-1),
+                                                  reduction='mean')
+                dice_loss += self.bce_weight * bce_loss
+
+            if self.class_weights is not None:
+                assert self.class_weights.shape[0] == n_classes, \
+                    'Expect weight shape [{}], got[{}]'.format(n_classes, self.class_weights.shape[0])
+                dice_loss *= self.class_weights[cls_index]
+
+            total_loss += dice_loss
+        total_loss /= total_class_weights
+
+        return self.weight*total_loss
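For orientation, a minimal usage sketch of the new loss. The key names are placeholders and the hierarchical batch_dict layout is assumed from the fuse conventions used elsewhere in this diff; the import path follows the file being changed here:

    import torch
    from fuse.losses.segmentation.loss_dice import DiceBCELoss

    # hypothetical batch_dict; shapes follow the docstring above
    batch_dict = {
        'model': {'seg': torch.softmax(torch.randn(2, 2, 32, 32), dim=1)},  # [batch, num_classes, H, W]
        'data': {'gt_seg': torch.randint(0, 2, (2, 1, 32, 32))},            # integer mask
    }

    loss_fn = DiceBCELoss(pred_name='model.seg',
                          target_name='data.gt_seg',
                          bce_weight=1.0)
    loss = loss_fn(batch_dict)  # scalar: mean of per-class (dice + weighted BCE)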


 class FuseDiceLoss(FuseLossBase):

     def __init__(self, pred_name,