Seg example #42

Open · wants to merge 53 commits into master

Changes from 2 commits

Commits (53)
70aaa31  Segmentation example - 1st commit (Mar 15, 2022)
c8674b0  fix elastic augmentation and loss bce+dice; update for new eval packa… (Mar 15, 2022)
98ebeee  Change to eval package (Mar 20, 2022)
62d3c1c  Merge branch 'master' into seg_example (mosheraboh, Mar 25, 2022)
26c1ffe  move example to a new folder (Mar 26, 2022)
59afc55  Update the create_data script + add some comment regarding the origin… (Mar 26, 2022)
310726e  remove old script and update main script according to PR comments (no… (Mar 26, 2022)
76f834d  change names to eval* (Mar 26, 2022)
13aa1da  Changes following the comments on the PR (Mar 27, 2022)
4ad76f2  remove commented code (Apr 4, 2022)
323c270  Merge branch 'master' into seg_example (Apr 4, 2022)
960a484  Merge branch 'master' into seg_example (mosheraboh, Apr 7, 2022)
5a4a6cf  change input desc to file names and processor to compute mask images … (Apr 10, 2022)
e1e601d  factor out end to end examples to seprate package (Apr 13, 2022)
75ce22b  add examples to PYTHONPATH (Apr 13, 2022)
c53c96c  run unittests in examples (Apr 13, 2022)
3551533  remove fuse1 data package (Apr 14, 2022)
97eac6b  remove dataset from manager (Apr 14, 2022)
dddac6d  convert mnist to fuse2 style (Apr 14, 2022)
83936d0  create dl package (Apr 17, 2022)
69a3b09  remove Fuse prefix (Apr 17, 2022)
5771c85  reorg examples (Apr 17, 2022)
13c5f0b  Merge branch 'master' of github.com:IBM/fuse-med-ml into fuse2 (Apr 17, 2022)
c12ff08  Merge branch 'fuse2' of github.com:IBM/fuse-med-ml into data_package (Apr 17, 2022)
3e3a476  add fuse data package (Apr 17, 2022)
5f614a0  Merge branch 'data_package' of github.com:IBM/fuse-med-ml into mnist_… (Apr 17, 2022)
9b511cc  adjust mnist runner (Apr 17, 2022)
532b347  imaging extension (Apr 18, 2022)
1a94187  Merge branch 'data_package' of github.com:IBM/fuse-med-ml into mnist_… (Apr 18, 2022)
92d78d1  Move changes from master's branch to mnist_fuse2_style's branch (Apr 18, 2022)
0b8b573  Fixed import path (Apr 18, 2022)
fb72386  remove the create-data script and move all its functionality to input… (Apr 19, 2022)
bab4bf7  Updated the notebook (mnist example) to fuse2 (Apr 25, 2022)
d4a5c02  Skip test - temp (Apr 25, 2022)
68df39c  Update test_notebook_hello_world.py (SagiPolaczek, Apr 25, 2022)
ae69698  Data package (#61) (mosheraboh, Apr 28, 2022)
b4804c3  skip test (it works locally) (Apr 28, 2022)
63f2a29  Move changes from master's branch to mnist_fuse2_style's branch (Apr 18, 2022)
c40f129  Fixed import path (Apr 18, 2022)
419a36b  Updated the notebook (mnist example) to fuse2 (Apr 25, 2022)
13536f0  Skip test - temp (Apr 25, 2022)
1077b8c  Update test_notebook_hello_world.py (SagiPolaczek, Apr 25, 2022)
85cdae4  skip test (it works locally) (Apr 28, 2022)
26f78f8  Merge branch 'hello_world_unittest' of github.com:IBM/fuse-med-ml int… (Apr 28, 2022)
0f3abf3  Fixed override in the set_device functionality and made cpu usage mor… (Apr 28, 2022)
5d0b65a  Merge pull request #63 from IBM/hello_world_unittest (SagiPolaczek, Apr 28, 2022)
3f71935  data package readme (mosheraboh, May 2, 2022)
17ef080  Merge with master (May 3, 2022)
1dc5050  merged with fuse2 (May 3, 2022)
2e1f98c  Fix import for fuse2 + add a static pipeline + change data source int… (May 10, 2022)
dbe44a5  Complete data pipeline including the dynamic part (May 10, 2022)
6c4232c  Working fuse2 version + fix to gaussian op data type (May 17, 2022)
0ef6742  remove comments and non-required input processor file (May 17, 2022)

39 changes: 30 additions & 9 deletions fuse/data/augmentor/augmentor_toolbox.py
@@ -18,7 +18,7 @@
"""

from copy import deepcopy
from typing import Tuple, Any, List, Iterable, Optional
from typing import Tuple, Any, List, Iterable, Optional, Union

import numpy
import torch
@@ -264,19 +264,23 @@ def aug_op_gaussian(aug_input: Tensor, mean: float = 0.0, std: float = 0.03, cha
return aug_tensor


def aug_op_elastic_transform(aug_input: Tensor, alpha: float = 1, sigma: float = 50, channels: Optional[List[int]] = None):
def aug_op_elastic_transform(aug_input: Tuple[Tensor],
sigma: float = 50,
num_points: int = 3):
"""Elastic deformation of images as described in [Simard2003]_.
.. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for
Convolutional Neural Networks applied to Visual Document Analysis",
:param aug_input: input tensor of shape (C,Y,X)
:param alpha: global pixel shifting (correlated to the article)
:param aug_input: list of tensors of shape (C,Y,X)
:param sigma: Gaussian filter parameter
:param channels: which channels to apply the augmentation
:param num_points: defines the resolution of the deformation grid,
see https://github.com/gvtulder/elasticdeform for more info.
:return distorted image
"""
Collaborator:
Why did you change it? Didn't the previous implementation work?
I see that you are not using channels - can you support that as well?

Collaborator (Author):
I find it more straightforward and easy to use.

# convert back to torch tensor
aug_input = [numpy.array(t) for t in aug_input]
aug_input_d = ed.deform_random_grid(aug_input, sigma=7, points=3, axis=[(1, 2), (1,2)])
# for a (ch X Rows X cols) image - deform the 2 last axis
axis = [(1,2) for _ in range(len(aug_input))]
aug_input_d = ed.deform_random_grid(aug_input, sigma=sigma, points=num_points, axis=axis)

aug_output = [torch.from_numpy(t) for t in aug_input_d]
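
A minimal usage sketch of the new signature (illustration only, not part of the diff; it assumes the function returns the deformed tensors in input order, as the snippet above suggests). Passing the image and its mask together deforms both with the same random grid, so they stay aligned:

import torch

image = torch.rand(1, 128, 128)                    # (C, Y, X)
mask = (torch.rand(1, 128, 128) > 0.5).float()     # same spatial shape as the image

# both tensors share one displacement grid, so the mask stays aligned with the image
image_d, mask_d = aug_op_elastic_transform((image, mask), sigma=7, num_points=3)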

@@ -454,7 +458,7 @@ def aug_op_batch_mix_up(aug_input: Tuple[Tensor, Tensor], factor: float) -> Tupl


def aug_op_random_crop_and_resize(aug_input: Tensor,
out_size,
out_size: Union[int, Tuple[int, int], Tuple[int, int, int]],
crop_size: float = 1.0, # or optional - Tuple[float, float]
x_off: float = 1.0,
y_off: float = 1.0,
@@ -463,8 +467,8 @@ def aug_op_random_crop_and_resize(aug_input: Tensor,
random crop a (3d) tensor and resize it to a given size
Collaborator:
Add here aug_input and which dimensions you expect

:param crop_size: float <= 1.0 - the fraction to crop from the original tensor for each dim
:param x_off: float <= 1.0 - the x-offset to take
:param y_off:
:param z_off:
:param y_off: float <= 1.0 - the y-offset to take
:param z_off: float <= 1.0 - the z-offset to take
:param out_size: the size of the output tensor
:return: the output tensor
"""
@@ -486,4 +490,21 @@ def aug_op_random_crop_and_resize(aug_input: Tensor,

aug_tensor = F.interpolate(aug_tensor, out_size)

elif len(aug_input.shape) == 3:
ch, y, x = in_shape

x_width = int(crop_size * x)
x_off = int(x_off * (x - x_width))

y_width = int(crop_size * y)
y_off = int(y_off * (y - y_width))

aug_tensor = aug_input[:, y_off:y_off+y_width, x_off:x_off+x_width]

aug_tensor = F.interpolate(aug_tensor, out_size)

# else:
Collaborator:
else throw error?




return aug_tensor
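
For reference on the new (C, Y, X) branch above, here is a self-contained sketch of the same crop-then-resize idea (illustration only, not the PR's code). It adds a batch dimension before calling F.interpolate, since interpolate matches the target size against the spatial dims of an (N, C, H, W) input; whether the branch above needs the same treatment depends on the shapes actually passed in:

import torch
import torch.nn.functional as F

def crop_and_resize_2d(img: torch.Tensor,
                       out_size: tuple,           # (out_y, out_x)
                       crop_size: float = 0.8,
                       y_off: float = 0.5,
                       x_off: float = 0.5) -> torch.Tensor:
    ch, y, x = img.shape
    y_width, x_width = int(crop_size * y), int(crop_size * x)
    y0 = int(y_off * (y - y_width))
    x0 = int(x_off * (x - x_width))
    cropped = img[:, y0:y0 + y_width, x0:x0 + x_width]
    # interpolate expects (N, C, H, W) for a 2D spatial resize
    resized = F.interpolate(cropped.unsqueeze(0), size=out_size, mode='bilinear', align_corners=False)
    return resized.squeeze(0)

crop_and_resize_2d(torch.rand(1, 128, 128), (64, 64)).shape   # torch.Size([1, 64, 64])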
13 changes: 7 additions & 6 deletions fuse/losses/segmentation/loss_dice.py
@@ -66,7 +66,7 @@ def __call__(self, predict, target):
target = target.contiguous().view(target.shape[0], -1)

if target.dtype == torch.int64 or target.dtype == torch.int32:
Collaborator:
Why do we care about the type?
Can we convert it anyway?

Collaborator:
I also think that the to(device) is not necessary here - can you remove it?

Collaborator (Author):
I remember I had to add this option; I'm not sure what the case was.
I removed both to(device) calls.

target = target.type(torch.float32).to(target.device)
target = target.type(torch.float32)
num = 2*torch.sum(torch.mul(predict, target), dim=1) + self.eps
den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.eps
loss = 1 - num / den
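
For readers unfamiliar with the loss, the hunk above computes a standard soft Dice per sample over the flattened spatial dims: loss = 1 - (2 * sum(p * t) + eps) / (sum(p**p_exp) + sum(t**p_exp) + eps). A toy re-computation, independent of Fuse (eps, the power and the final mean reduction are assumptions):

import torch

def soft_dice_loss(predict: torch.Tensor, target: torch.Tensor,
                   p: int = 1, eps: float = 1e-8) -> torch.Tensor:
    predict = predict.contiguous().view(predict.shape[0], -1)
    target = target.contiguous().view(target.shape[0], -1).float()
    num = 2 * torch.sum(predict * target, dim=1) + eps
    den = torch.sum(predict.pow(p) + target.pow(p), dim=1) + eps
    return (1 - num / den).mean()

pred = torch.sigmoid(torch.randn(4, 1, 32, 32))    # probabilities in [0, 1]
gt = (torch.rand(4, 1, 32, 32) > 0.5).long()
soft_dice_loss(pred, gt)                           # scalar tensor, lower is better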
@@ -85,8 +85,8 @@ def __call__(self, predict, target):
class DiceBCELoss(FuseLossBase):

def __init__(self,
pred_name,
target_name,
pred_name: str = None,
target_name: str = None,
filter_func: Optional[Callable]=None,
class_weights=None,
Collaborator:
type annotation

bce_weight: float=1.0,
@@ -124,7 +124,7 @@ def __call__(self, batch_dict):
predict = FuseUtilsHierarchicalDict.get(batch_dict, self.pred_name).float()
target = FuseUtilsHierarchicalDict.get(batch_dict, self.target_name).long()

target = target.type(torch.float32).to(target.device)
target = target.type(torch.float32)

total_loss = 0
n_classes = predict.shape[1]
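
DiceBCELoss above combines binary cross-entropy with the Dice term, weighted by bce_weight. A hedged sketch of one plausible per-class combination (the layout, weighting and reduction here are assumptions, not the actual Fuse implementation):

import torch
import torch.nn.functional as F

def dice_bce(predict: torch.Tensor, target: torch.Tensor,
             bce_weight: float = 1.0, eps: float = 1e-8) -> torch.Tensor:
    # predict and target are assumed to share an (N, C, Y, X) layout,
    # with predict holding per-class probabilities in [0, 1]
    total_loss = 0.0
    n_classes = predict.shape[1]
    for c in range(n_classes):
        p = predict[:, c].reshape(predict.shape[0], -1)
        t = target[:, c].reshape(target.shape[0], -1).float()
        bce = F.binary_cross_entropy(p, t)
        dice = 1 - (2 * (p * t).sum(dim=1) + eps) / ((p + t).sum(dim=1) + eps)
        total_loss = total_loss + bce_weight * bce + dice.mean()
    return total_loss / n_classes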
@@ -157,8 +157,9 @@ def __call__(self, batch_dict):

class FuseDiceLoss(FuseLossBase):

def __init__(self, pred_name,
target_name,
def __init__(self,
pred_name: str = None,
target_name: str = None,
filter_func: Optional[Callable] = None,
class_weights=None,
Collaborator:
type annotation

ignore_cls_index_list=None,
1 change: 1 addition & 0 deletions fuse_examples/segmentation/siim/README.md
@@ -0,0 +1 @@
# SIIM-ACR Pneumothorax Segmentation with Fuse
Collaborator:
I guess that you plan another commit to add the information

25 changes: 0 additions & 25 deletions fuse_examples/segmentation/siim/data_source_segmentation.py
@@ -69,41 +69,16 @@ def __init__(self,
repartition = pickle.load(splits)
sample_descs = repartition['val']
else:
# TODO - this option is not clear - if the partition file is not given, do we train
# with all the data? or just not save the partition? (then we will not be able
# to re-run the experiment ...)
for sample_id in input_df.iloc[:, 0]:
sample_descs.append(sample_id)

self.samples = sample_descs

self.input_source = [image_source, mask_source]

# prev version
# self.samples = input_source

# @staticmethod
# def filter_by_conditions(samples: pd.DataFrame, conditions: Optional[List[Dict[str, List]]]):
# """
# Returns a vector of the samples that passed the conditions
# :param samples: dataframe to check. expected to have at least sample_desc column.
# :param conditions: list of dictionaries. each dictionary has column name as keys and possible values as the values.
# for each dict in the list:
# the keys are applied with AND between them.
# the dict conditions are applied with OR between them.
# :return: boolean vector with the filtered samples
# """
# to_keep = samples.sample_desc.isna() # start with all false
# for condition_list in conditions:
# condition_to_keep = samples.sample_desc.notna() # start with all true
# for column, values in condition_list.items():
# condition_to_keep = condition_to_keep & samples[column].isin(values) # all conditions in list must be met
# to_keep = to_keep | condition_to_keep # add this condition samples to_keep
# return to_keep

def get_samples_description(self):
return self.samples
# return list(self.samples_df['sample_desc'])

def summary(self) -> str:
summary_str = ''
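
Regarding the TODO above about a missing partition file, a hypothetical helper that builds and saves such a split is sketched below (the 'train'/'val' keys follow the loading code; the function name and the use of scikit-learn are assumptions, not part of the PR):

import pickle
import pandas as pd
from sklearn.model_selection import train_test_split

def save_partition(csv_path: str, partition_file: str, val_fraction: float = 0.2) -> None:
    df = pd.read_csv(csv_path)
    sample_ids = list(df.iloc[:, 0])    # first column holds the sample descriptors
    train_ids, val_ids = train_test_split(sample_ids, test_size=val_fraction, random_state=42)
    with open(partition_file, 'wb') as f:
        pickle.dump({'train': train_ids, 'val': val_ids}, f)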
72 changes: 27 additions & 45 deletions fuse_examples/segmentation/siim/runner_seg.py
@@ -15,22 +15,6 @@

Created on June 30, 2021

"""

"""

(C) Copyright 2021 IBM Corp.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Created on June 30, 2021

"""
import os
import logging
@@ -105,7 +89,7 @@
'force_reset_model_dir': True, # If True will reset model dir automatically - otherwise will prompt 'are you sure' message.
Collaborator:
Let's not assume anything about ROOT and DATA_ROOT.
Please document what each one is, and ask the user to fill it in

'cache_dir': os.path.join(CACHE_PATH, EXPERIMENT_CACHE+'_cache_dir'),
'inference_dir': os.path.join(ROOT, EXPERIMENT, 'infer_dir'),
'analyze_dir': os.path.join(ROOT, EXPERIMENT, 'analyze_dir')}
'eval_dir': os.path.join(ROOT, EXPERIMENT, 'eval_dir')}
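
One possible way to address the review comment about ROOT and DATA_ROOT, sketched with placeholder names (the environment variables and the assert are illustrative assumptions, not the PR's code):

import os

ROOT = os.environ.get('SIIM_ROOT', '')             # where model_dir, cache, inference and eval outputs are written
DATA_ROOT = os.environ.get('SIIM_DATA_ROOT', '')   # where the downloaded SIIM-ACR dataset lives

assert ROOT and DATA_ROOT, 'please set SIIM_ROOT and SIIM_DATA_ROOT before running the example'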

##########################################
# Train Common Params
@@ -130,7 +114,8 @@
[
('data.input.input_0','data.gt.gt_global'),
aug_op_elastic_transform,
{},
{'sigma': 7,
'num_points': 3},
{'apply': RandBool(0.7)}
],
[
@@ -180,9 +165,6 @@ def run_train(paths: dict, train_common_params: dict):
fuse_logger_start(output_path=paths['model_dir'], console_verbose_level=logging.INFO)
lgr = logging.getLogger('Fuse')

# Download data
# TODO - function to download + arrange the data

lgr.info('\nFuse Train', {'attrs': ['bold', 'underline']})

lgr.info(f'model_dir={paths["model_dir"]}', {'color': 'magenta'})
@@ -214,7 +196,7 @@
# Create visualizer (optional)
visualiser = FuseVisualizerDefault(image_name='data.input.input_0',
mask_name='data.gt.gt_global',
pred_name='model.logits.classification')
pred_name='model.logits.segmentation')

train_dataset = FuseDatasetDefault(cache_dest=paths['cache_dir'],
data_source=train_data_source,
@@ -271,19 +253,19 @@

model = FuseModelWrapper(model=torch_model,
model_inputs=['data.input.input_0'],
model_outputs=['logits.classification']
model_outputs=['logits.segmentation']
)

lgr.info('Model: Done', {'attrs': 'bold'})
# ====================================================================================
# Loss
# ====================================================================================
losses = {
'dice_loss': DiceBCELoss(pred_name='model.logits.classification',
'dice_loss': DiceBCELoss(pred_name='model.logits.segmentation',
target_name='data.gt.gt_global')
}

model = model.cuda()
# model = model.cuda()
Collaborator:
remove

optimizer = optim.SGD(model.parameters(),
lr=train_common_params['manager.learning_rate'],
momentum=0.9,
@@ -363,7 +345,7 @@ def run_infer(paths: dict, infer_common_params: dict):
# Create visualizer (optional)
visualiser = FuseVisualizerDefault(image_name='data.input.input_0',
mask_name='data.gt.gt_global',
pred_name='model.logits.classification')
pred_name='model.logits.segmentation')

infer_dataset = FuseDatasetDefault(cache_dest=paths['cache_dir'],
data_source=infer_data_source,
@@ -387,8 +369,8 @@

#### Manager for inference
manager = FuseManagerDefault()
# extract just the global classification per sample and save to a file
output_columns = ['model.logits.classification', 'data.gt.gt_global']
# extract just the global segmentation per sample and save to a file
output_columns = ['model.logits.segmentation', 'data.gt.gt_global']
manager.infer(data_loader=infer_dataloader,
input_model_dir=paths['model_dir'],
checkpoint=infer_common_params['checkpoint'],
@@ -398,7 +380,7 @@
# visualize the predictions
infer_processor = FuseProcessorDataFrame(data_pickle_filename=infer_common_params['infer_filename'])
descriptors_list = infer_processor.get_samples_descriptors()
out_name = 'model.logits.classification'
out_name = 'model.logits.segmentation'
gt_name = 'data.gt.gt_global'
for desc in descriptors_list[:10]:
data = infer_processor(desc)
@@ -413,32 +395,32 @@
plt.savefig(fn)
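
The loop above saves a figure per sample; a hedged sketch of what such a side-by-side image / prediction / ground-truth figure could look like (helper name and layout are assumptions, not the PR's code):

import numpy as np
import matplotlib.pyplot as plt

def save_seg_figure(image: np.ndarray, pred_mask: np.ndarray, gt_mask: np.ndarray, fn: str) -> None:
    fig, axes = plt.subplots(1, 3, figsize=(12, 4))
    for ax, (title, arr) in zip(axes, [('image', image), ('prediction', pred_mask), ('ground truth', gt_mask)]):
        ax.imshow(arr.squeeze(), cmap='gray')   # squeeze drops a singleton channel dim if present
        ax.set_title(title)
        ax.axis('off')
    fig.savefig(fn)
    plt.close(fig)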

######################################
# Analyze Common Params
# Evaluation Common Params
######################################
ANALYZE_COMMON_PARAMS = {}
ANALYZE_COMMON_PARAMS['infer_filename'] = INFER_COMMON_PARAMS['infer_filename']
ANALYZE_COMMON_PARAMS['output_filename'] = os.path.join(PATHS['analyze_dir'], 'all_metrics.txt')
ANALYZE_COMMON_PARAMS['num_workers'] = 4
ANALYZE_COMMON_PARAMS['batch_size'] = 8
EVAL_COMMON_PARAMS = {}
EVAL_COMMON_PARAMS['infer_filename'] = INFER_COMMON_PARAMS['infer_filename']
EVAL_COMMON_PARAMS['output_filename'] = os.path.join(PATHS['eval_dir'], 'all_metrics.txt')
EVAL_COMMON_PARAMS['num_workers'] = 4
EVAL_COMMON_PARAMS['batch_size'] = 8

######################################
# Analyze Template
######################################
def run_analyze(paths: dict, analyze_common_params: dict):
def run_eval(paths: dict, eval_common_params: dict):
fuse_logger_start(output_path=None, console_verbose_level=logging.INFO)
lgr = logging.getLogger('Fuse')
lgr.info('Fuse Analyze', {'attrs': ['bold', 'underline']})
lgr.info('Fuse eval', {'attrs': ['bold', 'underline']})

# define iterator
def data_iter():
data = pd.read_pickle(analyze_common_params['infer_filename'])
data = pd.read_pickle(eval_common_params['infer_filename'])
n_samples = data.shape[0]
threshold = 1e-7 #0.5
for inx in range(n_samples):
row = data.loc[inx]
sample_dict = {}
sample_dict["id"] = row['id']
sample_dict["pred.array"] = row['model.logits.classification'] > threshold
sample_dict["pred.array"] = row['model.logits.segmentation'] > threshold
sample_dict["label.array"] = row['data.gt.gt_global']
yield sample_dict

@@ -449,7 +431,7 @@ def data_iter():
("PixelAcc", MetricPixelAccuracy(pred='pred.array', target='label.array')),
])
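
The iterator above feeds boolean prediction masks and ground-truth masks into the evaluator. For reference, a standalone sketch of what the three metrics reduce to for a single binary mask pair (the Fuse metric classes may aggregate over samples differently):

import numpy as np

def binary_seg_metrics(pred: np.ndarray, target: np.ndarray, eps: float = 1e-8) -> dict:
    pred = pred.astype(bool)
    target = target.astype(bool)
    inter = np.logical_and(pred, target).sum()
    union = np.logical_or(pred, target).sum()
    return {
        'Dice': (2 * inter + eps) / (pred.sum() + target.sum() + eps),
        'IouJaccard': (inter + eps) / (union + eps),
        'PixelAcc': (pred == target).mean(),
    }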

# create analyzer
# create evaluator
evaluator = EvaluatorDefault()

results = evaluator.eval(ids=None,
@@ -470,8 +452,8 @@ def data_iter():
force_gpus = None # [0]
FuseUtilsGPU.choose_and_enable_multiple_gpus(NUM_GPUS, force_gpus=force_gpus)

RUNNING_MODES = ['train', 'infer', 'analyze'] # Options: 'train', 'infer', 'analyze'
# RUNNING_MODES = ['analyze'] # Options: 'train', 'infer', 'analyze'
RUNNING_MODES = ['train', 'infer', 'eval'] # Options: 'train', 'infer', 'eval'
# RUNNING_MODES = ['eval'] # Options: 'train', 'infer', 'eval'
Collaborator:
remove


# train
if 'train' in RUNNING_MODES:
@@ -481,7 +463,7 @@ def data_iter():
if 'infer' in RUNNING_MODES:
run_infer(paths=PATHS, infer_common_params=INFER_COMMON_PARAMS)

# analyze
if 'analyze' in RUNNING_MODES:
run_analyze(paths=PATHS, analyze_common_params=ANALYZE_COMMON_PARAMS)
# eval
if 'eval' in RUNNING_MODES:
run_eval(paths=PATHS, eval_common_params=EVAL_COMMON_PARAMS)
