Multiclass semantic segmentation on the Cityscapes and KITTI datasets.
Semantic segmentation is essentially pixel-level classification and is well known in the deep-learning community. There are several "state of the art" approaches for building such models. Basically, we need a fully-convolutional network with a pretrained backbone for feature extraction, mapping an input image to the given masks (say, each output channel represents an individual class).
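To make the "one output channel per class" idea concrete, here is a tiny sketch; the tensors below just stand in for a real model's input and output, only the shapes matter:

```python
import torch

# A segmentation model takes a batch of RGB images (B, 3, H, W) and returns
# one logit map per class, i.e. a tensor of shape (B, num_classes, H, W).
num_classes = 20
images = torch.randn(4, 3, 512, 1024)            # (B, 3, H, W)
logits = torch.randn(4, num_classes, 512, 1024)  # what model(images) would return

# Per-pixel class map: argmax over the channel dimension.
pred_mask = logits.argmax(dim=1)                 # (B, H, W), values in [0, num_classes)
print(pred_mask.shape)                           # torch.Size([4, 512, 1024])
```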
Here is an example of cityscapes annotation:
In this repo I want to show how to train two of the most popular architectures - UNet and FPN (with pretty large ResNeXt50 encoders).
I also want to give an idea of where these semantic masks can be used in the self-driving/robotics field: one use case is generating a "prior" for point-cloud clustering algorithms. You might ask: why semantic segmentation, when panoptic/instance segmentation would fit this case better? Well, my answer is: semantic segmentation models are a lot simpler and faster to understand and train.
Both UNet and FPN use features from different scales, and I'll quote some really insightful words from the web about the difference between the two:
> The main difference is that there are multiple prediction layers: one for each upsampling layer.
> Like the U-Net, the FPN has lateral connections between the bottom-up pyramid (left) and the top-down pyramid (right).
> But where U-Net only copies the features and appends them, FPN applies a 1x1 convolution layer before adding them.
> This allows the bottom-up pyramid, called the "backbone", to be pretty much whatever you want.
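Here is a minimal sketch of those two merge strategies; the channel counts and the 1x1 layer below are made up purely for illustration:

```python
import torch
import torch.nn as nn

encoder_feat = torch.randn(1, 256, 64, 128)   # feature map from the bottom-up (encoder) path
decoder_feat = torch.randn(1, 128, 64, 128)   # upsampled feature map from the top-down (decoder) path

# U-Net style merge: concatenate along the channel dimension,
# so the decoder has to match the encoder's channel counts.
unet_merged = torch.cat([encoder_feat, decoder_feat], dim=1)   # -> (1, 384, 64, 128)

# FPN style merge: a 1x1 convolution projects the encoder features to the
# decoder's channel count first, then the maps are simply added;
# this is what makes the backbone easily swappable.
lateral = nn.Conv2d(256, 128, kernel_size=1)
fpn_merged = lateral(encoder_feat) + decoder_feat              # -> (1, 128, 64, 128)
```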
Check out the UNet paper, which also gives an idea of how to separate instances (via border prediction).
And this presentation from the FPN paper authors.
To get a short summary of a model, I prefer using torchsummary.
The torchsummary lib may require a little hack to work with the FPN implementation. Make the following edit in torchsummary.py:
...
try:
    summary[m_key]["input_shape"] = list(input[0].size())
except:  # fall back for modules whose input is a list/tuple of tensors
    summary[m_key]["input_shape"] = list(input[0][0].size())
...
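For reference, under the hood a model summary boils down to a plain torchsummary call like the sketch below; the toy model is just a stand-in for the repo's UNet/FPN models:

```python
import torch
import torch.nn as nn
from torchsummary import summary

# Any nn.Module works; a toy conv net stands in for the real models here.
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(16, 8, kernel_size=1),
)

device = "cuda" if torch.cuda.is_available() else "cpu"
# torchsummary takes the input size without the batch dimension,
# which is what the --input_size 3,512,1024 argument maps to.
summary(model.to(device), input_size=(3, 512, 1024), device=device)
```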
Example (you will see the results in stdout):
python model_summary.py \
--model_type fpn \
--backbone resnext50 \
--unet_res_blocks 0 \
--input_size 3,512,1024
`--model_type unet --backbone resnext50 --unet_res_blocks 1 --input_size 3,512,1024`:
Total params: 92,014,472
Forward/backward pass size (MB): 10082.50
Params size (MB): 351.01

`--model_type unet --backbone resnext50 --unet_res_blocks 0 --input_size 3,512,1024`:
Total params: 81,091,464
Forward/backward pass size (MB): 9666.50
Params size (MB): 309.34

`--model_type fpn --backbone resnext50 --input_size 3,512,1024`:
Total params: 25,588,808
Forward/backward pass size (MB): 4574.11
Params size (MB): 97.61
As we can see, the FPN segmentation model is a lot "lighter", so we can use a larger batch size ;).
To monitor the training process, we can set up tensorboard (on CPU):
CUDA_VISIBLE_DEVICES="" tensorboard --logdir /samsung_drive/semantic_segmentation/%MODEL_DIR%/tensorboard
Logs are sent from the main training loop in the Trainer class.
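A rough sketch of the kind of scalar logging the Trainer performs with tensorboardX (the tag names and values below are made up):

```python
from tensorboardX import SummaryWriter

# Writer pointed at the directory that tensorboard is watching.
writer = SummaryWriter(log_dir="./tensorboard_demo")

for epoch in range(10):
    # In the real Trainer these would be the epoch's train loss and validation metric.
    train_loss = 1.0 / (epoch + 1)
    val_dice = 1.0 - 1.0 / (epoch + 2)
    # One scalar per curve; tensorboard groups curves by the tag prefix.
    writer.add_scalar("loss/train", train_loss, epoch)
    writer.add_scalar("metric/val_dice", val_dice, epoch)

writer.close()
```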
Here are examples of a typical training process (UNet on all Cityscapes classes):
Core functionality is implemented in the Trainer class (./utils/trainer.py).
The training/evaluation pipeline is straightforward: you just need to fill out train_config.yaml according to the selected model and dataset.
Let's take a look at each module of the training/evaluation configuration:
- `TARGET` - the dataset you want to train on; it affects the preprocessing stage (`cityscapes` or `kitti`);
- `PATH` - paths to the train, validation and test datasets;
- `DATASET` - here we set the image sizes, select the classes to train on and control augmentations;
- `MODEL` - declares the model type, backbone and number of classes (affects the last layers of the network);
- `TRAINING` - all training process properties: GPUs, loss, metric, class weights, etc.;
- `EVAL` - paths to store the predictions, the test-time augmentation flag, thresholds, etc.
There is an example config file in the root dir.
To use different configs, just pass them to the train/eval scripts as arguments:
python train.py --config_path ./train_config.yaml
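Roughly, the entry point just parses the config path and loads the YAML. The sketch below is an assumption about train.py's internals; only the section names come from the list above:

```python
import argparse
import yaml

parser = argparse.ArgumentParser()
parser.add_argument("--config_path", type=str, default="./train_config.yaml")
args = parser.parse_args()

with open(args.config_path) as f:
    config = yaml.safe_load(f)

# The sections described above come out as plain nested values/dicts, e.g.:
target = config["TARGET"]       # dataset name: cityscapes or kitti
model_cfg = config["MODEL"]     # model type, backbone, number of classes
train_cfg = config["TRAINING"]  # GPUs, loss, metric, class weights, ...
```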
For training, I mostly used the fine-annotated part of the Cityscapes dataset (~3k examples). There is also a large amount of coarse-annotated data, which is obviously suitable for pre-training, but I didn't use this part of the dataset in order to save time on training.
After training on the Cityscapes dataset (in the road segmentation case), you can easily use this model as initialization for the KITTI dataset to segment roads/lanes.
The Cityscapes dataset also gives you a choice between using all classes or categories - classes aggregated by certain properties.
I've implemented all dataset-specific preprocessing in the cityscapes_utils.py and kitti_lane_utils.py scripts.
I used the dice coefficient (which is equivalent to the F1 score) as the default metric. In segmentation problems, intersection over union (IoU) and dice are the metrics usually applied for evaluation: IoU = |A∩B| / |A∪B|, while dice = 2|A∩B| / (|A| + |B|). They are positively correlated, but the dice coefficient tends to measure some average performance across all classes and examples. Here is a nice visualization of IoU (on the left) and dice (on the right):
Now about the loss: I used a weighted sum of binary cross-entropy and dice loss, as in the binary classification case here. You can also easily use IoU instead of dice, since their implementations are very similar.
BCE is calculated on logits for numerical stability.
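Here is a minimal sketch of such a combined loss; the weights and the smoothing constant are arbitrary, not the repo's exact values:

```python
import torch
import torch.nn as nn

def soft_dice_loss(logits, targets, eps=1e-7):
    """1 - dice, computed on probabilities; targets are float 0/1 masks."""
    probs = torch.sigmoid(logits)
    dims = (2, 3)  # sum over the spatial dimensions, keep batch and channel
    intersection = (probs * targets).sum(dims)
    cardinality = probs.sum(dims) + targets.sum(dims)
    dice = (2.0 * intersection + eps) / (cardinality + eps)
    return 1.0 - dice.mean()

class BCEDiceLoss(nn.Module):
    def __init__(self, bce_weight=0.5, dice_weight=0.5):
        super().__init__()
        self.bce = nn.BCEWithLogitsLoss()  # BCE on logits for numerical stability
        self.bce_weight = bce_weight
        self.dice_weight = dice_weight

    def forward(self, logits, targets):
        return (self.bce_weight * self.bce(logits, targets)
                + self.dice_weight * soft_dice_loss(logits, targets))

# logits: (B, C, H, W), targets: float binary masks of the same shape
criterion = BCEDiceLoss()
loss = criterion(torch.randn(2, 8, 64, 128),
                 torch.randint(0, 2, (2, 8, 64, 128)).float())
```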
Augmentation is a well-established technique for extending a dataset: we slightly modify both the image and the mask. Here I apply augmentations "on the fly" along with batch generation, via the well-known albumentations library.
Usually I end up with a mix of a couple of spatial and RGB augmentations, like crops/flips + random contrast/brightness (you can check it out in ./utils/cityscapes_utils.py).
Also, sometimes you want to apply really hard augmentations to imitate images from another "conditions distribution", like snow, shadows, rain, etc. Albumentations gives you that possibility via this code.
Here is an original image:
Here is a "darkened" flipped and slightly cropped image:
This is the previous augmented image with random rain, a light beam, and snow:
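For reference, here is a sketch of what such albumentations pipelines can look like; the exact transforms and probabilities used in the repo may differ:

```python
import numpy as np
import albumentations as A

# "Soft" mix of spatial and RGB augmentations, applied to image and mask together.
train_aug = A.Compose([
    A.RandomCrop(height=256, width=512),
    A.HorizontalFlip(p=0.5),
    A.RandomBrightnessContrast(p=0.5),
])

# "Hard" augmentations imitating other weather/lighting conditions.
hard_aug = A.Compose([
    A.RandomRain(p=0.3),
    A.RandomSnow(p=0.3),
    A.RandomSunFlare(p=0.3),
    A.RandomShadow(p=0.3),
])

image = np.random.randint(0, 255, (512, 1024, 3), dtype=np.uint8)
mask = np.random.randint(0, 20, (512, 1024), dtype=np.uint8)

augmented = train_aug(image=image, mask=mask)   # spatial transforms move image and mask together
aug_image, aug_mask = augmented["image"], augmented["mask"]
aug_image = hard_aug(image=aug_image)["image"]  # RGB-only augs don't need the mask
```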
Another way to use augmentations to increase model performance is to apply some "soft" deterministic affine transformations, like flips, and then average the predictions (I've read a great analogy about how a human can look at an image from different angles to better understand what is shown there).
This process is called test-time augmentation, or simply TTA. The downside is that we need a prediction for each transform, which leads to longer inference time. Here is a visual explanation of how this works:
tsharpen is just (x_1^t + ... + x_N^t) / N, i.e. each prediction is raised to the power t before averaging
I use a simple arithmetic mean, but you can try, for instance, the geometric mean, tsharpen, etc. Check out the code here: /utils/TTA.py.
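A minimal sketch of flip-only TTA with an arithmetic mean (the repo's /utils/TTA.py may organize this differently):

```python
import torch

def predict_tta(model, images):
    """Average sigmoid predictions over the identity and a horizontal flip."""
    model.eval()
    with torch.no_grad():
        probs = torch.sigmoid(model(images))
        # Flip the input, predict, then flip the prediction back before averaging.
        flipped = torch.flip(images, dims=[3])
        probs += torch.flip(torch.sigmoid(model(flipped)), dims=[3])
    return probs / 2.0

# Usage with a toy "model":
model = torch.nn.Conv2d(3, 1, kernel_size=1)
averaged = predict_tta(model, torch.randn(2, 3, 64, 128))   # -> (2, 1, 64, 128)
```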
As a post-processing step, I detect clusters below a certain area and replace them with the background class, which leads to a "jitter" effect on small and far-away masks (check out /utils/utils.py --> DropClusters).
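A sketch of this kind of post-processing using scikit-image connected components (the actual DropClusters implementation may differ):

```python
import numpy as np
from skimage import measure

def drop_small_clusters(mask, min_area=500):
    """Set connected components smaller than min_area pixels to background (0)."""
    cleaned = mask.copy()
    labels = measure.label(mask > 0)             # connected components of the foreground
    for region in measure.regionprops(labels):
        if region.area < min_area:
            cleaned[labels == region.label] = 0  # replace with the background class
    return cleaned

mask = np.zeros((128, 256), dtype=np.uint8)
mask[10:20, 10:20] = 1       # small blob (100 px) -> removed
mask[40:120, 40:200] = 1     # large blob -> kept
print(drop_small_clusters(mask).sum())
```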
Here I took the best of 40 epochs of training on 2x down-sized images (512x1024) for 2 and 8 classes, and on 8x down-sized images for 20 classes (to fit the batch into GPU memory).
Models for 2 and 8 classes were trained in two stages: first on smaller images (256x512) and then on only 2x resized ones (512x1024).
Dice metric comparison table:
Classes # | Unet | FPN | Size |
---|---|---|---|
2 (road segmentation) | 0.956 (weights) | 0.956 (weights) | 256x512 >> 512x1024 |
8 (categories only) | 0.929 (weights) | 0.931 (weights) | 256x512 >> 512x1024 |
20 | 0.852 (weights) | 0.858 (weights) | 128x256 |
8 classes:
Model | void | flat | construction | object | nature | sky | human | vehicle |
---|---|---|---|---|---|---|---|---|
FPN | 0.769 | 0.954 | 0.889 | 0.573 | 0.885 | 0.804 | 0.492 | 0.897 |
UNET | 0.750 | 0.958 | 0.888 | 0.561 | 0.884 | 0.806 | 0.479 | 0.890 |
20 classes:
Model | road | sidewalk | building | wall | fence | pole | traffic light | traffic sign | vegetation | terrain | sky | person | rider | car | truck | bus | train | motorcycle | bicycle | unlabeled |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
FPN | 0.943 | 0.562 | 0.777 | 0.011 | 0.046 | 0.041 | 0.318 | 0.128 | 0.808 | 0.178 | 0.747 | 0.132 | 0.0132 | 0.759 | 0.010 | 0.022 | 0.013 | 0.005 | 0.072 | 0.216 |
UNET | 0.944 | 0.608 | 0.785 | 0.020 | 0.017 | 0.131 | 0.321 | 0.161 | 0.822 | 0.236 | 0.765 | 0.141 | 0.000 | 0.780 | 0.001 | 0.002 | 0.001 | 0.000 | 0.056 | 0.112 |
Interestingly, I expected the FPN architecture to perform better on multiclass problems, but on average both UNet and FPN give pretty similar dice metrics.
Yes, there are a couple of classes that the FPN segmentation model detects better (marked in the table), but the absolute dice values for those classes are not that high.
Summary:
In general, if you're dealing with a generic segmentation problem with pretty large, nicely separable objects, FPN looks like a good choice for both binary and multiclass segmentation in terms of segmentation quality and computational efficiency. At the same time, I've noticed that FPN produces more small gaps in masks compared to UNet. Check out the videos below:
Prediction on cityscapes demo videos (Stuttgart):
Classes # | UNET | FPN |
---|---|---|
2 | 00, 01, 02 | 00, 01, 02 |
8 | 00, 01, 02 | 00, 01, 02 |
20 | 00, 01, 02 | 00, 01, 02 |
I used ffmpeg to make videos from image sequences on Linux:
ffmpeg -f image2 -framerate 20 \
-pattern_type glob -i 'stuttgart_00_*.png' \
-c:v libx264 -pix_fmt yuv420p ../stuttgart_00.mp4
Check out this awesome repo with high-quality implementations of the basic semantic segmentation algorithms.
Again, I strongly suggest using Deepo as a simple experimentation environment.
When you're done with your code, it's better to build your own Docker container and keep the latest version somewhere like Docker Hub.
Anyway, here are some key dependencies for this repo:
pip install --upgrade tqdm \
torchsummary \
tensorboardX \
albumentations==0.4.1 \
torch==1.1.0 \
torchvision==0.4.0