forked from wolny/pytorch-3dunet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_config.yml
156 lines (151 loc) · 5.34 KB
/
train_config.yml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# Sample configuration file for training a 3D U-Net on the task of predicting boundaries in 3D stacks of the Arabidopsis
# lateral root acquired with a light-sheet microscope. Training uses a combination of binary cross-entropy and Dice loss.
# Download training data from: https://osf.io/9x3g2/
# Download validation data from: https://osf.io/vs6gb/
# Download test data from: https://osf.io/tn4xj/
model:
  name: ResidualUNet3D
  # number of input channels to the model
  in_channels: 1
  # number of output channels
  out_channels: 1
  # determines the order of operators in a single layer
  # (letters: g = GroupNorm, c = Conv3d, r = ReLU; 'gcr' = GroupNorm+Conv3d+ReLU)
  layer_order: gcr
  # initial number of feature maps
  f_maps: 32
  # number of groups in the GroupNorm
  num_groups: 8
  # apply element-wise nn.Sigmoid after the final 1x1x1 convolution, otherwise apply nn.Softmax
  final_sigmoid: true
# training loss: a combination of binary cross-entropy and Dice loss
loss:
  name: BCEDiceLoss
  # target value excluded from the loss (and its gradient); null means no value is ignored
  ignore_index: null
  # drop the last target channel before computing the loss: it holds the original
  # ground-truth labels appended by StandardLabelToBoundary (append_label: true),
  # which are only needed by the evaluation metric
  skip_last_target: true
optimizer:
  learning_rate: 0.0002  # initial learning rate
  weight_decay: 0.00001  # weight decay (regularization) coefficient
# metric computed on the validation set (lower is better for this error metric)
eval_metric:
  # boundary-based variant of the Adapted Rand error
  name: BoundaryAdaptedRandError
  threshold: 0.4          # threshold applied to the predicted probability maps
  use_last_target: true   # compare against the last target channel (the appended original labels)
  use_first_input: true   # evaluate only the first channel of the prediction
lr_scheduler:
  name: ReduceLROnPlateau
  # 'min' mode is required because a lower (Boundary)AdaptedRandError is better
  mode: min
  # multiplicative decay applied when the monitored score plateaus (new_lr = lr * factor)
  factor: 0.2
  # number of scheduler steps without improvement tolerated before the LR is reduced
  patience: 20
trainer:
  # model with lower eval score is considered better
  # (the eval metric, BoundaryAdaptedRandError, is an error measure)
  eval_score_higher_is_better: false
  # path to the checkpoint directory (placeholder: replace CHECKPOINT_DIR)
  checkpoint_dir: CHECKPOINT_DIR
  # path to latest checkpoint; if provided the training will be resumed from that checkpoint
  resume: null
  # path to the best_checkpoint.pytorch; to be used for fine-tuning the model with additional ground truth
  # make sure to decrease the learning rate in the optimizer config accordingly
  pre_trained: null
  # how many iterations between validations
  validate_after_iters: 1000
  # how many iterations between tensorboard logging
  log_after_iters: 500
  # max number of epochs
  max_num_epochs: 1000
  # max number of iterations
  max_num_iterations: 150000
# Configure training and validation loaders
loaders:
  # how many subprocesses to use for data loading
  num_workers: 8
  # path to the raw data within the H5 files
  raw_internal_path: /raw
  # path to the label data within the H5 files
  label_internal_path: /label
  # configuration of the train loader
  train:
    # paths to the training datasets (placeholder: replace PATH_TO_TRAIN_DIR)
    file_paths:
      - PATH_TO_TRAIN_DIR
    # SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
    slice_builder:
      name: FilterSliceBuilder
      # train patch size given to the network (adapt to fit in your GPU mem, generally the bigger patch the better)
      patch_shape: [80, 170, 170]
      # train stride between patches (smaller than patch_shape, so training patches overlap)
      stride_shape: [20, 40, 40]
      # minimum volume of the labels in the patch
      threshold: 0.6
      # probability of accepting patches which do not fulfil the threshold criterion
      slack_acceptance: 0.01
    # data augmentation. The geometric transforms (RandomFlip, RandomRotate90, RandomRotate,
    # ElasticDeformation) appear in the same order for raw and label; presumably the loader
    # drives both lists with the same random state so they stay aligned (TODO confirm) --
    # keep the two lists in sync when editing.
    transformer:
      raw:
        - name: Standardize
        - name: RandomFlip
        - name: RandomRotate90
        - name: RandomRotate
          # rotate only in the YX plane (axes 2 and 1, assuming ZYX axis order),
          # i.e. never across the anisotropic Z axis
          axes: [[2, 1]]
          angle_spectrum: 45
          mode: reflect
        - name: ElasticDeformation
          spline_order: 3
        - name: GaussianBlur3D
          execution_probability: 0.5
        - name: AdditiveGaussianNoise
          execution_probability: 0.2
        - name: AdditivePoissonNoise
          execution_probability: 0.2
        - name: ToTensor
          expand_dims: true
      label:
        - name: RandomFlip
        - name: RandomRotate90
        - name: RandomRotate
          # same rotation plane as for the raw data (YX, assuming ZYX axis order)
          axes: [[2, 1]]
          angle_spectrum: 45
          mode: reflect
        - name: ElasticDeformation
          # order-0 spline (nearest neighbour) so label values are not interpolated
          spline_order: 0
        - name: StandardLabelToBoundary
          # append original ground truth labels to the last channel (to be able to compute the eval metric)
          append_label: true
        - name: ToTensor
          expand_dims: false
  # configuration of the val loader
  val:
    # paths to the validation datasets (placeholder: replace PATH_TO_VAL_DIR)
    file_paths:
      - PATH_TO_VAL_DIR
    # SliceBuilder configuration, i.e. how to iterate over the input volume patch-by-patch
    slice_builder:
      name: FilterSliceBuilder
      # validation patch size (same as the training patch size)
      patch_shape: [80, 170, 170]
      # stride equals patch_shape, so validation patches do not overlap
      stride_shape: [80, 170, 170]
      # minimum volume of the labels in the patch
      threshold: 0.6
      # probability of accepting patches which do not fulfil the threshold criterion
      slack_acceptance: 0.01
    # preprocessing only (normalization and tensor conversion); no augmentation for validation
    transformer:
      raw:
        - name: Standardize
        - name: ToTensor
          expand_dims: true
      label:
        - name: StandardLabelToBoundary
          # append original ground truth labels to the last channel (needed by the eval metric)
          append_label: true
        - name: ToTensor
          expand_dims: false