-
Notifications
You must be signed in to change notification settings - Fork 35
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Seg example #42
base: master
Are you sure you want to change the base?
Seg example #42
Changes from 2 commits
70aaa31
c8674b0
98ebeee
62d3c1c
26c1ffe
59afc55
310726e
76f834d
13aa1da
4ad76f2
323c270
960a484
5a4a6cf
e1e601d
75ce22b
c53c96c
3551533
97eac6b
dddac6d
83936d0
69a3b09
5771c85
13c5f0b
c12ff08
3e3a476
5f614a0
9b511cc
532b347
1a94187
92d78d1
0b8b573
fb72386
bab4bf7
d4a5c02
68df39c
ae69698
b4804c3
63f2a29
c40f129
419a36b
13536f0
1077b8c
85cdae4
26f78f8
0f3abf3
5d0b65a
3f71935
17ef080
1dc5050
2e1f98c
dbe44a5
6c4232c
0ef6742
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ | |
""" | ||
|
||
from copy import deepcopy | ||
from typing import Tuple, Any, List, Iterable, Optional | ||
from typing import Tuple, Any, List, Iterable, Optional, Union | ||
|
||
import numpy | ||
import torch | ||
|
@@ -264,19 +264,23 @@ def aug_op_gaussian(aug_input: Tensor, mean: float = 0.0, std: float = 0.03, cha | |
return aug_tensor | ||
|
||
|
||
def aug_op_elastic_transform(aug_input: Tensor, alpha: float = 1, sigma: float = 50, channels: Optional[List[int]] = None): | ||
def aug_op_elastic_transform(aug_input: Tuple[Tensor], | ||
sigma: float = 50, | ||
num_points: int = 3): | ||
"""Elastic deformation of images as described in [Simard2003]_. | ||
.. [Simard2003] Simard, Steinkraus and Platt, "Best Practices for | ||
Convolutional Neural Networks applied to Visual Document Analysis", | ||
:param aug_input: input tensor of shape (C,Y,X) | ||
:param alpha: global pixel shifting (correlated to the article) | ||
:param aug_input: list of tensors of shape (C,Y,X) | ||
:param sigma: Gaussian filter parameter | ||
:param channels: which channels to apply the augmentation | ||
:param num_points: define the resolution of the deformation gris | ||
see https://github.com/gvtulder/elasticdeform for more info. | ||
:return distorted image | ||
""" | ||
# convert back to torch tensor | ||
aug_input = [numpy.array(t) for t in aug_input] | ||
aug_input_d = ed.deform_random_grid(aug_input, sigma=7, points=3, axis=[(1, 2), (1,2)]) | ||
# for a (ch X Rows X cols) image - deform the 2 last axis | ||
axis = [(1,2) for _ in range(len(aug_input))] | ||
aug_input_d = ed.deform_random_grid(aug_input, sigma=sigma, points=num_points, axis=axis) | ||
|
||
aug_output = [torch.from_numpy(t) for t in aug_input_d] | ||
|
||
|
@@ -454,7 +458,7 @@ def aug_op_batch_mix_up(aug_input: Tuple[Tensor, Tensor], factor: float) -> Tupl | |
|
||
|
||
def aug_op_random_crop_and_resize(aug_input: Tensor, | ||
out_size, | ||
out_size: Union[int, Tuple[int, int], Tuple[int, int, int]], | ||
crop_size: float = 1.0, # or optional - Tuple[float, float] | ||
x_off: float = 1.0, | ||
y_off: float = 1.0, | ||
|
@@ -463,8 +467,8 @@ def aug_op_random_crop_and_resize(aug_input: Tensor, | |
random crop a (3d) tensor and resize it to a given size | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add here aug_input and which dimensions you expect |
||
:param crop_size: float <= 1.0 - the fraction to crop from the original tensor for each dim | ||
:param x_off: float <= 1.0 - the x-offset to take | ||
:param y_off: | ||
:param z_off: | ||
:param y_off: float <= 1.0 - the y-offset to take | ||
:param z_off: float <= 1.0 - the z-offset to take | ||
:param out_size: the size of the output tensor | ||
:return: the output tensor | ||
""" | ||
|
@@ -486,4 +490,21 @@ def aug_op_random_crop_and_resize(aug_input: Tensor, | |
|
||
aug_tensor = F.interpolate(aug_tensor, out_size) | ||
|
||
elif len(aug_input.shape) == 3: | ||
ch, y, x = in_shape | ||
|
||
x_width = int(crop_size * x) | ||
x_off = int(x_off * (x - x_width)) | ||
|
||
y_width = int(crop_size * y) | ||
y_off = int(y_off * (y - y_width)) | ||
|
||
aug_tensor = aug_input[:, y_off:y_off+y_width, x_off:x_off+x_width] | ||
|
||
aug_tensor = F.interpolate(aug_tensor, out_size) | ||
|
||
# else: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. else throw error? |
||
|
||
|
||
|
||
return aug_tensor |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -66,7 +66,7 @@ def __call__(self, predict, target): | |
target = target.contiguous().view(target.shape[0], -1) | ||
|
||
if target.dtype == torch.int64 or target.dtype == torch.int32: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why do we care about the type? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I also think that the to device is not necessary here - can you remove it> There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I remember I had to add this option. not sure what was the case. |
||
target = target.type(torch.float32).to(target.device) | ||
target = target.type(torch.float32) | ||
num = 2*torch.sum(torch.mul(predict, target), dim=1) + self.eps | ||
den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.eps | ||
loss = 1 - num / den | ||
|
@@ -85,8 +85,8 @@ def __call__(self, predict, target): | |
class DiceBCELoss(FuseLossBase): | ||
|
||
def __init__(self, | ||
pred_name, | ||
target_name, | ||
pred_name: str = None, | ||
target_name: str = None, | ||
filter_func: Optional[Callable]=None, | ||
class_weights=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type annotation |
||
bce_weight: float=1.0, | ||
|
@@ -124,7 +124,7 @@ def __call__(self, batch_dict): | |
predict = FuseUtilsHierarchicalDict.get(batch_dict, self.pred_name).float() | ||
target = FuseUtilsHierarchicalDict.get(batch_dict, self.target_name).long() | ||
|
||
target = target.type(torch.float32).to(target.device) | ||
target = target.type(torch.float32) | ||
|
||
total_loss = 0 | ||
n_classes = predict.shape[1] | ||
|
@@ -157,8 +157,9 @@ def __call__(self, batch_dict): | |
|
||
class FuseDiceLoss(FuseLossBase): | ||
|
||
def __init__(self, pred_name, | ||
target_name, | ||
def __init__(self, | ||
pred_name: str = None, | ||
target_name: str = None, | ||
filter_func: Optional[Callable] = None, | ||
class_weights=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. type annotation |
||
ignore_cls_index_list=None, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
# SIIM-ACR Pneumothorax Segmentation with Fute | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I guess that you plan another commit to add the information |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,22 +15,6 @@ | |
|
||
Created on June 30, 2021 | ||
|
||
""" | ||
|
||
""" | ||
|
||
(C) Copyright 2021 IBM Corp. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
Created on June 30, 2021 | ||
|
||
""" | ||
import os | ||
import logging | ||
|
@@ -105,7 +89,7 @@ | |
'force_reset_model_dir': True, # If True will reset model dir automatically - otherwise will prompt 'are you sure' message. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's not assume anything about ROOT and DATA_ROOT. |
||
'cache_dir': os.path.join(CACHE_PATH, EXPERIMENT_CACHE+'_cache_dir'), | ||
'inference_dir': os.path.join(ROOT, EXPERIMENT, 'infer_dir'), | ||
'analyze_dir': os.path.join(ROOT, EXPERIMENT, 'analyze_dir')} | ||
'eval_dir': os.path.join(ROOT, EXPERIMENT, 'eval_dir')} | ||
|
||
########################################## | ||
# Train Common Params | ||
|
@@ -130,7 +114,8 @@ | |
[ | ||
('data.input.input_0','data.gt.gt_global'), | ||
aug_op_elastic_transform, | ||
{}, | ||
{'sigma': 7, | ||
'num_points': 3}, | ||
{'apply': RandBool(0.7)} | ||
], | ||
[ | ||
|
@@ -180,9 +165,6 @@ def run_train(paths: dict, train_common_params: dict): | |
fuse_logger_start(output_path=paths['model_dir'], console_verbose_level=logging.INFO) | ||
lgr = logging.getLogger('Fuse') | ||
|
||
# Download data | ||
# TODO - function to download + arrange the data | ||
|
||
lgr.info('\nFuse Train', {'attrs': ['bold', 'underline']}) | ||
|
||
lgr.info(f'model_dir={paths["model_dir"]}', {'color': 'magenta'}) | ||
|
@@ -214,7 +196,7 @@ def run_train(paths: dict, train_common_params: dict): | |
# Create visualizer (optional) | ||
visualiser = FuseVisualizerDefault(image_name='data.input.input_0', | ||
mask_name='data.gt.gt_global', | ||
pred_name='model.logits.classification') | ||
pred_name='model.logits.segmentation') | ||
|
||
train_dataset = FuseDatasetDefault(cache_dest=paths['cache_dir'], | ||
data_source=train_data_source, | ||
|
@@ -271,19 +253,19 @@ def run_train(paths: dict, train_common_params: dict): | |
|
||
model = FuseModelWrapper(model=torch_model, | ||
model_inputs=['data.input.input_0'], | ||
model_outputs=['logits.classification'] | ||
model_outputs=['logits.segmentation'] | ||
) | ||
|
||
lgr.info('Model: Done', {'attrs': 'bold'}) | ||
# ==================================================================================== | ||
# Loss | ||
# ==================================================================================== | ||
losses = { | ||
'dice_loss': DiceBCELoss(pred_name='model.logits.classification', | ||
'dice_loss': DiceBCELoss(pred_name='model.logits.segmentation', | ||
target_name='data.gt.gt_global') | ||
} | ||
|
||
model = model.cuda() | ||
# model = model.cuda() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove |
||
optimizer = optim.SGD(model.parameters(), | ||
lr=train_common_params['manager.learning_rate'], | ||
momentum=0.9, | ||
|
@@ -363,7 +345,7 @@ def run_infer(paths: dict, infer_common_params: dict): | |
# Create visualizer (optional) | ||
visualiser = FuseVisualizerDefault(image_name='data.input.input_0', | ||
mask_name='data.gt.gt_global', | ||
pred_name='model.logits.classification') | ||
pred_name='model.logits.segmentation') | ||
|
||
infer_dataset = FuseDatasetDefault(cache_dest=paths['cache_dir'], | ||
data_source=infer_data_source, | ||
|
@@ -387,8 +369,8 @@ def run_infer(paths: dict, infer_common_params: dict): | |
|
||
#### Manager for inference | ||
manager = FuseManagerDefault() | ||
# extract just the global classification per sample and save to a file | ||
output_columns = ['model.logits.classification', 'data.gt.gt_global'] | ||
# extract just the global segmentation per sample and save to a file | ||
output_columns = ['model.logits.segmentation', 'data.gt.gt_global'] | ||
manager.infer(data_loader=infer_dataloader, | ||
input_model_dir=paths['model_dir'], | ||
checkpoint=infer_common_params['checkpoint'], | ||
|
@@ -398,7 +380,7 @@ def run_infer(paths: dict, infer_common_params: dict): | |
# visualize the predictions | ||
infer_processor = FuseProcessorDataFrame(data_pickle_filename=infer_common_params['infer_filename']) | ||
descriptors_list = infer_processor.get_samples_descriptors() | ||
out_name = 'model.logits.classification' | ||
out_name = 'model.logits.segmentation' | ||
gt_name = 'data.gt.gt_global' | ||
for desc in descriptors_list[:10]: | ||
data = infer_processor(desc) | ||
|
@@ -413,32 +395,32 @@ def run_infer(paths: dict, infer_common_params: dict): | |
plt.savefig(fn) | ||
|
||
###################################### | ||
# Analyze Common Params | ||
# Evaluation Common Params | ||
###################################### | ||
ANALYZE_COMMON_PARAMS = {} | ||
ANALYZE_COMMON_PARAMS['infer_filename'] = INFER_COMMON_PARAMS['infer_filename'] | ||
ANALYZE_COMMON_PARAMS['output_filename'] = os.path.join(PATHS['analyze_dir'], 'all_metrics.txt') | ||
ANALYZE_COMMON_PARAMS['num_workers'] = 4 | ||
ANALYZE_COMMON_PARAMS['batch_size'] = 8 | ||
EVAL_COMMON_PARAMS = {} | ||
EVAL_COMMON_PARAMS['infer_filename'] = INFER_COMMON_PARAMS['infer_filename'] | ||
EVAL_COMMON_PARAMS['output_filename'] = os.path.join(PATHS['eval_dir'], 'all_metrics.txt') | ||
EVAL_COMMON_PARAMS['num_workers'] = 4 | ||
EVAL_COMMON_PARAMS['batch_size'] = 8 | ||
|
||
###################################### | ||
# Analyze Template | ||
###################################### | ||
def run_analyze(paths: dict, analyze_common_params: dict): | ||
def run_eval(paths: dict, eval_common_params: dict): | ||
fuse_logger_start(output_path=None, console_verbose_level=logging.INFO) | ||
lgr = logging.getLogger('Fuse') | ||
lgr.info('Fuse Analyze', {'attrs': ['bold', 'underline']}) | ||
lgr.info('Fuse eval', {'attrs': ['bold', 'underline']}) | ||
|
||
# define iterator | ||
def data_iter(): | ||
data = pd.read_pickle(analyze_common_params['infer_filename']) | ||
data = pd.read_pickle(eval_common_params['infer_filename']) | ||
n_samples = data.shape[0] | ||
threshold = 1e-7 #0.5 | ||
for inx in range(n_samples): | ||
row = data.loc[inx] | ||
sample_dict = {} | ||
sample_dict["id"] = row['id'] | ||
sample_dict["pred.array"] = row['model.logits.classification'] > threshold | ||
sample_dict["pred.array"] = row['model.logits.segmentation'] > threshold | ||
sample_dict["label.array"] = row['data.gt.gt_global'] | ||
yield sample_dict | ||
|
||
|
@@ -449,7 +431,7 @@ def data_iter(): | |
("PixelAcc", MetricPixelAccuracy(pred='pred.array', target='label.array')), | ||
]) | ||
|
||
# create analyzer | ||
# create evaluator | ||
evaluator = EvaluatorDefault() | ||
|
||
results = evaluator.eval(ids=None, | ||
|
@@ -470,8 +452,8 @@ def data_iter(): | |
force_gpus = None # [0] | ||
FuseUtilsGPU.choose_and_enable_multiple_gpus(NUM_GPUS, force_gpus=force_gpus) | ||
|
||
RUNNING_MODES = ['train', 'infer', 'analyze'] # Options: 'train', 'infer', 'analyze' | ||
# RUNNING_MODES = ['analyze'] # Options: 'train', 'infer', 'analyze' | ||
RUNNING_MODES = ['train', 'infer', 'eval'] # Options: 'train', 'infer', 'eval' | ||
# RUNNING_MODES = ['eval'] # Options: 'train', 'infer', 'eval' | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove |
||
|
||
# train | ||
if 'train' in RUNNING_MODES: | ||
|
@@ -481,7 +463,7 @@ def data_iter(): | |
if 'infer' in RUNNING_MODES: | ||
run_infer(paths=PATHS, infer_common_params=INFER_COMMON_PARAMS) | ||
|
||
# analyze | ||
if 'analyze' in RUNNING_MODES: | ||
run_analyze(paths=PATHS, analyze_common_params=ANALYZE_COMMON_PARAMS) | ||
# eval | ||
if 'eval' in RUNNING_MODES: | ||
run_eval(paths=PATHS, eval_common_params=EVAL_COMMON_PARAMS) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why did you change it? the previous implementation didn't work?
I see that you are not using channels - can you support it as well?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I find it more straightforward and easy to use.