-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Seg example #42
base: master
Are you sure you want to change the base?
Seg example #42
Changes from 9 commits
70aaa31
c8674b0
98ebeee
62d3c1c
26c1ffe
59afc55
310726e
76f834d
13aa1da
4ad76f2
323c270
960a484
5a4a6cf
e1e601d
75ce22b
c53c96c
3551533
97eac6b
dddac6d
83936d0
69a3b09
5771c85
13c5f0b
c12ff08
3e3a476
5f614a0
9b511cc
532b347
1a94187
92d78d1
0b8b573
fb72386
bab4bf7
d4a5c02
68df39c
ae69698
b4804c3
63f2a29
c40f129
419a36b
13536f0
1077b8c
85cdae4
26f78f8
0f3abf3
5d0b65a
3f71935
17ef080
1dc5050
2e1f98c
dbe44a5
6c4232c
0ef6742
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ | |
""" | ||
|
||
from copy import deepcopy | ||
from typing import Tuple, Any, List, Iterable, Optional | ||
from typing import Tuple, Any, List, Iterable, Optional, Union | ||
|
||
import numpy | ||
import torch | ||
|
@@ -27,12 +27,16 @@ | |
from scipy.ndimage.filters import gaussian_filter | ||
from scipy.ndimage.interpolation import map_coordinates | ||
from torch import Tensor | ||
import elasticdeform as ed | ||
|
||
from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerGaussianPatch as Gaussian | ||
from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerRandBool as RandBool | ||
from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerRandInt as RandInt | ||
from fuse.utils.utils_param_sampler import FuseUtilsParamSamplerUniform as Uniform | ||
|
||
import PIL | ||
import torch.nn.functional as F | ||
|
||
|
||
######## Affine augmentation | ||
def aug_op_affine(aug_input: Tensor, rotate: float = 0.0, translate: Tuple[float, float] = (0.0, 0.0), | ||
|
@@ -63,7 +67,12 @@ def aug_op_affine(aug_input: Tensor, rotate: float = 0.0, translate: Tuple[float | |
for channel in channels: | ||
aug_channel_tensor = aug_input[channel].numpy() | ||
aug_channel_tensor = Image.fromarray(aug_channel_tensor) | ||
aug_channel_tensor = TTF.affine(aug_channel_tensor, angle=rotate, scale=scale, translate=translate, shear=shear) | ||
aug_channel_tensor = TTF.affine(aug_channel_tensor, | ||
angle=rotate, | ||
scale=scale, | ||
resample=PIL.Image.BILINEAR, | ||
translate=translate, | ||
shear=shear) | ||
if flip[0]: | ||
aug_channel_tensor = TTF.vflip(aug_channel_tensor) | ||
if flip[1]: | ||
|
@@ -255,33 +264,27 @@ def aug_op_gaussian(aug_input: Tensor, mean: float = 0.0, std: float = 0.03, cha | |
return aug_tensor | ||
|
||
|
||
def aug_op_elastic_transform(aug_input: Tensor, alpha: float = 1, sigma: float = 50, channels: Optional[List[int]] = None): | ||
def aug_op_elastic_transform(aug_input: Tuple[Tensor], | ||
sigma: float = 50, | ||
num_points: int = 3): | ||
"""Elastic deformation of images as described in [Simard2003]_. | ||
.. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for | ||
Convolutional Neural Networks applied to Visual Document Analysis", | ||
:param aug_input: input tensor of shape (C,Y,X) | ||
:param alpha: global pixel shifting (correlated to the article) | ||
:param aug_input: list of tensors of shape (C,Y,X) | ||
:param sigma: Gaussian filter parameter | ||
:param channels: which channels to apply the augmentation | ||
:param num_points: define the resolution of the deformation gris | ||
see https://github.com/gvtulder/elasticdeform for more info. | ||
:return distorted image | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why did you change it? the previous implementation didn't work? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I find it more straightforward and easy to use. |
||
random_state = numpy.random.RandomState(None) | ||
if channels is None: | ||
channels = list(range(aug_input.shape[0])) | ||
aug_tensor = aug_input.numpy() | ||
for channel in channels: | ||
aug_channel_tensor = aug_input[channel].numpy() | ||
shape = aug_channel_tensor.shape | ||
dx1 = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha | ||
dx2 = gaussian_filter((random_state.rand(*shape) * 2 - 1), sigma, mode="constant", cval=0) * alpha | ||
# convert back to torch tensor | ||
aug_input = [numpy.array(t) for t in aug_input] | ||
# for a (ch X Rows X cols) image - deform the 2 last axis | ||
axis = [(1,2) for _ in range(len(aug_input))] | ||
aug_input_d = ed.deform_random_grid(aug_input, sigma=sigma, points=num_points, axis=axis) | ||
|
||
x1, x2 = numpy.meshgrid(numpy.arange(shape[0]), numpy.arange(shape[1])) | ||
indices = numpy.reshape(x2 + dx2, (-1, 1)), numpy.reshape(x1 + dx1, (-1, 1)) | ||
aug_output = [torch.from_numpy(t) for t in aug_input_d] | ||
|
||
distored_image = map_coordinates(aug_channel_tensor, indices, order=1, mode='reflect') | ||
distored_image = distored_image.reshape(aug_channel_tensor.shape) | ||
aug_tensor[channel] = distored_image | ||
return torch.from_numpy(aug_tensor) | ||
return aug_output | ||
|
||
|
||
######### Default / Example augmentation pipline for a 2D image | ||
|
@@ -452,3 +455,56 @@ def aug_op_batch_mix_up(aug_input: Tuple[Tensor, Tensor], factor: float) -> Tupl | |
img = img * (1.0 - factor) + factor * img_mix_up | ||
labels = labels * (1.0 - factor) + factor * labels_mix_up | ||
return img, labels | ||
|
||
|
||
def aug_op_random_crop_and_resize(aug_input: Tensor, | ||
out_size: Union[int, Tuple[int, int], Tuple[int, int, int]], | ||
crop_size: float = 1.0, # or optional - Tuple[float, float] | ||
x_off: float = 1.0, | ||
y_off: float = 1.0, | ||
z_off: float = 1.0) -> Tensor: | ||
""" | ||
random crop a (3d) tensor and resize it to a given size | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add here aug_input and which dimensions you expect |
||
:param crop_size: float <= 1.0 - the fraction to crop from the original tensor for each dim | ||
:param x_off: float <= 1.0 - the x-offset to take | ||
:param y_off: float <= 1.0 - the y-offset to take | ||
:param z_off: float <= 1.0 - the z-offset to take | ||
:param out_size: the size of the output tensor | ||
:return: the output tensor | ||
""" | ||
in_shape = aug_input.shape | ||
|
||
if len(aug_input.shape) == 4: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. else not supported error? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I added support for 2d tensor. Need your help on how to handle other dimensions (error?) |
||
ch, z, y, x = in_shape | ||
|
||
x_width = int(crop_size * x) | ||
x_off = int(x_off * (x - x_width)) | ||
|
||
y_width = int(crop_size * y) | ||
y_off = int(y_off * (y - y_width)) | ||
|
||
z_width = int(crop_size * z) | ||
z_off = int(z_off * (z - z_width)) | ||
|
||
aug_tensor = aug_input[:, z_off:z_off+z_width, y_off:y_off+y_width, x_off:x_off+x_width] | ||
|
||
aug_tensor = F.interpolate(aug_tensor, out_size) | ||
|
||
elif len(aug_input.shape) == 3: | ||
ch, y, x = in_shape | ||
|
||
x_width = int(crop_size * x) | ||
x_off = int(x_off * (x - x_width)) | ||
|
||
y_width = int(crop_size * y) | ||
y_off = int(y_off * (y - y_width)) | ||
|
||
aug_tensor = aug_input[:, y_off:y_off+y_width, x_off:x_off+x_width] | ||
|
||
aug_tensor = F.interpolate(aug_tensor, out_size) | ||
|
||
# else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. else throw error? |
||
|
||
|
||
|
||
return aug_tensor |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -65,8 +65,8 @@ def __call__(self, predict, target): | |
predict = predict.contiguous().view(predict.shape[0], -1) | ||
target = target.contiguous().view(target.shape[0], -1) | ||
|
||
if target.dtype == torch.int64: | ||
target = target.type(torch.float32).to(target.device) | ||
if target.dtype == torch.int64 or target.dtype == torch.int32: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we care about the type? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I also think that the to device is not necessary here - can you remove it> There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I remember I had to add this option. not sure what was the case. |
||
target = target.type(torch.float32) | ||
num = 2*torch.sum(torch.mul(predict, target), dim=1) + self.eps | ||
den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.eps | ||
loss = 1 - num / den | ||
|
@@ -82,10 +82,84 @@ def __call__(self, predict, target): | |
raise Exception('Unexpected reduction {}'.format(self.reduction)) | ||
|
||
|
||
class DiceBCELoss(FuseLossBase): | ||
|
||
def __init__(self, | ||
pred_name: str = None, | ||
target_name: str = None, | ||
filter_func: Optional[Callable]=None, | ||
class_weights=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type annotation |
||
bce_weight: float=1.0, | ||
power: int=1, | ||
eps: float=1., | ||
reduction: str='mean'): | ||
''' | ||
Compute a weighted sum of dice-loss and cross entropy loss. | ||
|
||
:param pred_name: batch_dict key for predicted output (e.g., class probabilities after softmax). | ||
Expected Tensor shape = [batch, num_classes, height, width] | ||
:param target_name: batch_dict key for target (e.g., ground truth label). Expected Tensor shape = [batch, height, width] | ||
:param filter_func: function that filters batch_dict/ The function gets ans input batch_dict and returns filtered batch_dict | ||
:param class_weights: An array of shape [num_classes,] | ||
:param bce_weight: weight to attach to the bce loss, default : 1.0 | ||
:param power: Denominator value: \sum{x^p} + \sum{y^p}, default: 1 | ||
:param eps: A float number to smooth loss, and avoid NaN error, default: 1 | ||
:param reduction: Reduction method to apply, return mean over batch if 'mean', | ||
return sum if 'sum', return a tensor of shape [N,] if 'none' | ||
|
||
Returns: Loss tensor according to arg reduction | ||
Raise: Exception if unexpected reduction | ||
''' | ||
|
||
super().__init__(pred_name, target_name, 1.0) | ||
self.class_weights = class_weights | ||
self.bce_weight = bce_weight | ||
self.filter_func = filter_func | ||
self.dice = BinaryDiceLoss(power, eps, reduction) | ||
|
||
def __call__(self, batch_dict): | ||
|
||
if self.filter_func is not None: | ||
batch_dict = self.filter_func(batch_dict) | ||
predict = FuseUtilsHierarchicalDict.get(batch_dict, self.pred_name).float() | ||
target = FuseUtilsHierarchicalDict.get(batch_dict, self.target_name).long() | ||
|
||
target = target.type(torch.float32) | ||
|
||
total_loss = 0 | ||
n_classes = predict.shape[1] | ||
|
||
# Convert target to one hot encoding | ||
if n_classes > 1 and target.shape[1] != n_classes: | ||
target = make_one_hot(target, n_classes) | ||
|
||
assert predict.shape == target.shape, 'predict & target shape do not match' | ||
|
||
total_class_weights = sum(self.class_weights) if self.class_weights is not None else n_classes | ||
for cls_index in range(n_classes): | ||
dice_loss = self.dice(predict[:, cls_index, :, :], target[:, cls_index, :, :]) | ||
if self.bce_weight > 0.0: | ||
bce_loss = F.binary_cross_entropy(predict[:, cls_index, :, :].view(-1), | ||
target[:, cls_index, :, :].view(-1), | ||
reduction='mean') | ||
dice_loss += self.bce_weight * bce_loss | ||
|
||
if self.class_weights is not None: | ||
assert self.class_weights.shape[0] == n_classes, \ | ||
'Expect weight shape [{}], got[{}]'.format(n_classes, self.class_weights.shape[0]) | ||
dice_loss *= self.class_weights[cls_index] | ||
|
||
total_loss += dice_loss | ||
total_loss /= total_class_weights | ||
|
||
return self.weight*total_loss | ||
|
||
|
||
class FuseDiceLoss(FuseLossBase): | ||
|
||
def __init__(self, pred_name, | ||
target_name, | ||
def __init__(self, | ||
pred_name: str = None, | ||
target_name: str = None, | ||
filter_func: Optional[Callable] = None, | ||
class_weights=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type annotation |
||
ignore_cls_index_list=None, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# SIIM-ACR Pneumothorax Segmentation with Fute | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess that you plan another commit to add the information |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
import pydicom | ||
from pathlib import Path | ||
import pandas as pd | ||
from tqdm import tqdm as progress_bar | ||
import PIL | ||
import numpy as np | ||
import matplotlib.pylab as plt | ||
|
||
|
||
""" | ||
download dataset from - | ||
https://www.kaggle.com/seesee/siim-train-test | ||
|
||
The path to the extracted data should be updated in the <dataset_path> variable. | ||
The output images will be stored at <main_out_path>. | ||
the output size is defined by <out_size_list> (the output is created with a folder for each size) | ||
""" | ||
########################################## | ||
# Params | ||
########################################## | ||
main_out_path = '../siim_data' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You should get it from ROOT_DATA, right? |
||
dataset_path = '../siim/' | ||
out_size_list = [256, 512] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. function argument? |
||
|
||
|
||
def rle2mask(rles, width, height): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. missing type annotations and documentation in this file |
||
""" | ||
|
||
rle encoding if images | ||
input: rles(list of rle), width and height of image | ||
returns: mask of shape (width,height) | ||
""" | ||
|
||
mask= np.zeros(width* height) | ||
for rle in rles: | ||
array = np.asarray([int(x) for x in rle.split()]) | ||
starts = array[0::2] | ||
lengths = array[1::2] | ||
|
||
current_position = 0 | ||
for index, start in enumerate(starts): | ||
current_position += start | ||
mask[current_position:current_position+lengths[index]] = 255 | ||
current_position += lengths[index] | ||
|
||
return mask.reshape(width, height).T | ||
|
||
|
||
def filter_files(files, include=[], exclude=[]): | ||
for incl in include: | ||
files = [f for f in files if incl in f.name] | ||
for excl in exclude: | ||
files = [f for f in files if excl not in f.name] | ||
return sorted(files) | ||
|
||
|
||
def ls(x, recursive=False, include=[], exclude=[]): | ||
if not recursive: | ||
out = list(x.iterdir()) | ||
else: | ||
out = [o for o in x.glob('**/*')] | ||
out = filter_files(out, include=include, exclude=exclude) | ||
return out | ||
|
||
|
||
Path.ls = ls | ||
|
||
|
||
class InOutPath(): | ||
def __init__(self, input_path:Path, output_path:Path): | ||
if isinstance(input_path, str): input_path = Path(input_path) | ||
if isinstance(output_path, str): output_path = Path(output_path) | ||
self.inp = input_path | ||
self.out = output_path | ||
self.mkoutdir() | ||
|
||
def mkoutdir(self): | ||
self.out.mkdir(exist_ok=True, parents=True) | ||
|
||
def __repr__(self): | ||
return '\n'.join([f'{i}: {o}' for i, o in self.__dict__.items()]) + '\n' | ||
|
||
|
||
def dcm2png(SZ, dataset): | ||
path = InOutPath(Path(dataset_path + f'/dicom-images-{dataset}'), Path(main_out_path + f'/data{SZ}/{dataset}')) | ||
files = path.inp.ls(recursive=True, include=['.dcm']) | ||
for f in progress_bar(files): | ||
dcm = pydicom.read_file(str(f)).pixel_array | ||
im = PIL.Image.fromarray(dcm).resize((SZ,SZ)) | ||
im.save(path.out/f'{f.stem}.png') | ||
|
||
|
||
def masks2png(SZ): | ||
path = InOutPath(Path('data'), Path(main_out_path + f'/data{SZ}/masks')) | ||
for i in progress_bar(list(set(rle_df.ImageId.values))): | ||
I = rle_df.ImageId == i | ||
name = rle_df.loc[I, 'ImageId'] | ||
enc = rle_df.loc[I, ' EncodedPixels'] | ||
if sum(I) == 1: | ||
enc = enc.values[0] | ||
name = name.values[0] | ||
if enc == '-1': # ' -1': | ||
m = np.zeros((1024, 1024)).astype(np.uint8) | ||
else: | ||
m = rle2mask([enc], 1024, 1024).astype(np.uint8) | ||
PIL.Image.fromarray(m).resize((SZ,SZ)).save(f'{path.out}/{name}.png') | ||
else: | ||
m = rle2mask(enc.values, 1024, 1024).astype(np.uint8) | ||
PIL.Image.fromarray(m).resize((SZ,SZ)).save(f'{path.out}/{name.values[0]}.png') | ||
|
||
|
||
|
||
if __name__ == '__main__': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do you require to call this script? |
||
rle_df = pd.read_csv(dataset_path + '/train-rle.csv') | ||
|
||
for SZ in progress_bar(out_size_list): | ||
print(f'Converting data for train{SZ}') | ||
dcm2png(SZ, 'train') | ||
print(f'Converting data for test{SZ}') | ||
dcm2png(SZ, 'test') | ||
print(f'Generating masks for size {SZ}') | ||
masks2png(SZ) | ||
|
||
for SZ in progress_bar(out_size_list): | ||
# Missing masks set to 0 | ||
print('Generating missing masks as zeros') | ||
train_images = [o.name for o in Path(main_out_path + f'/data{SZ}/train').ls(include=['.png'])] | ||
train_masks = [o.name for o in Path(main_out_path + f'/data{SZ}/masks').ls(include=['.png'])] | ||
missing_masks = set(train_images) - set(train_masks) | ||
path = InOutPath(Path('data'), Path(main_out_path + f'/data{SZ}/masks')) | ||
for name in progress_bar(missing_masks): | ||
m = np.zeros((1024, 1024)).astype(np.uint8).T | ||
PIL.Image.fromarray(m).resize((SZ,SZ)).save(main_out_path + f'/data{SZ}/masks/{name}') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a new dependency - add it to requirements.txt
If there are more, please add them too.