A personal project to learn about semantic segmentation using PyTorch.
Implementation of the UNet architecture, as described in Ronneberger et al., for the task of semantic segmentation of the humerus bone in X-ray images.
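The core building block of UNet is a pair of 3x3 convolutions, each followed by ReLU, repeated at every level of the encoder and decoder. A minimal sketch of that block is shown below; the class name and channel counts are illustrative, not the exact ones used in this repository's model_builder.py.

```python
# A minimal sketch of UNet's repeated (conv -> ReLU) x 2 block, following
# Ronneberger et al. Names and channel counts here are illustrative only.
import torch
from torch import nn

class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU (padding keeps the spatial size)."""
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)

# One encoder step: double conv, then 2x2 max pooling halves H and W.
encoder_step = nn.Sequential(DoubleConv(1, 64), nn.MaxPool2d(2))
```

Feeding a 1-channel 256x256 image through this step produces a 64-channel 128x128 feature map; the full network stacks several such steps, then mirrors them with up-convolutions and skip connections.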
Reproduction of results is carried out using a subset of the MURA dataset.
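When reproducing segmentation results, a standard way to compare a predicted mask against the ground truth is the Dice coefficient. The stand-alone version below works on flat 0/1 sequences; it illustrates the metric only, and the repository's own evaluation code may compute it differently.

```python
# Dice coefficient: 2*|A∩B| / (|A| + |B|) over flat binary masks.
# Stand-alone illustration; not necessarily the implementation used in utils.py.
def dice_coefficient(pred, target, eps=1e-7):
    """Overlap score in [0, 1]; eps avoids division by zero on empty masks."""
    intersection = sum(p * t for p, t in zip(pred, target))
    total = sum(pred) + sum(target)
    return (2.0 * intersection + eps) / (total + eps)
```

A perfect prediction scores 1.0, and a prediction with no overlap scores close to 0.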
The main code is located in the train.py file. All other code files are imported into train.py for training and testing the model.
The code to perform segmentation on custom images is present in predict.py file.
For your reference, the UNet architecture diagram (from Ronneberger et al.) is attached below.
The MURA dataset provides the X-ray images but not the ground-truth segmentation masks.
Hence, ground-truth annotations were created using the LabelMe software. The created masks were later validated by medical professionals.
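LabelMe saves each annotation as polygon vertices in a JSON file, so a binary mask has to be rasterized from those points before training. A minimal sketch using Pillow is shown below; the "shapes"/"points" keys follow LabelMe's JSON format, but the function name is made up for illustration and this is not necessarily how the masks in this project were produced.

```python
# Rasterize LabelMe polygon annotations into a single binary mask image.
# Sketch only: assumes LabelMe's standard JSON layout ("shapes" -> "points").
import json
from PIL import Image, ImageDraw

def labelme_json_to_mask(json_path: str, width: int, height: int) -> Image.Image:
    """Draw every polygon in a LabelMe JSON file onto one single-channel mask."""
    with open(json_path) as f:
        annotation = json.load(f)
    mask = Image.new("L", (width, height), 0)  # all pixels start as background
    draw = ImageDraw.Draw(mask)
    for shape in annotation["shapes"]:
        points = [tuple(p) for p in shape["points"]]
        draw.polygon(points, fill=255)  # foreground pixels
    return mask
```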
These instructions will get you a copy of the project up and running on your local machine.
The project is structured as follows:
UNet-PyTorch/
├── config/
│ └── __init__.py
├── data/
│ ├── train/
│ │ ├── images/
│ │ └── masks/
│ └── test/
│ ├── images/
│ └── masks/
├── models/
├── requirements.txt
├── utils.py
├── data_setup.py
├── model_builder.py
├── engine.py
├── train.py
└── predict.py
Ensure that your directory structure matches the layout above. In particular, make sure your data folder follows this format. For reference, an empty data directory with this structure is included in the project.
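Before launching training, you can sanity-check that the data folder matches the layout above with a few lines of standard-library Python. The helper below is a sketch (it is not part of the repository); the expected sub-paths mirror the tree shown in this README.

```python
# Check that data/ contains the train/test images/masks layout described above.
# Stand-alone sketch using only the standard library; not part of the repo.
from pathlib import Path

EXPECTED = ["train/images", "train/masks", "test/images", "test/masks"]

def missing_data_dirs(data_dir: str) -> list:
    """Return the expected sub-directories that are missing under data_dir."""
    root = Path(data_dir)
    return [sub for sub in EXPECTED if not (root / sub).is_dir()]
```

An empty return value means the layout is complete; anything else lists the folders you still need to create.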
You need a machine with Python 3.6+ and a Bash-compatible shell (e.g. zsh) installed:
$ python3.8 -V
Python 3.8.18
$ echo $SHELL
/bin/zsh
Clone the repository:
$ git clone https://github.com/ajaystar8/UNet-PyTorch.git
Install the requirements in a new conda environment:
$ conda create -n name_of_env python=3.8 --file requirements.txt
Navigate to the config package and specify the following:
- Path to your data directory.
- Path to your models directory for saving model checkpoints.
- Other hyperparameters, if necessary.
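As an illustration, a config/__init__.py along these lines would cover the items above. The variable names and values here are placeholders only; the actual names in this repository's config package may differ, so adapt rather than copy.

```python
# Hypothetical sketch of config/__init__.py -- names and values are
# placeholders; substitute your own paths and the repo's actual identifiers.
DATA_DIR = "data"            # path to your data directory
CHECKPOINT_DIR = "models"    # directory for saving model checkpoints

# Default hyperparameters, matching the defaults shown by train.py --help
INPUT_DIMS = (256, 256)
NUM_EPOCHS = 10
BATCH_SIZE = 1
LEARNING_RATE = 1e-4
```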
Activate the conda environment:
$ conda activate name_of_env
To start training the model, call the train.py script. Most of the parameters and hyperparameters for training and testing can be set manually. You can list the available command-line arguments by executing:
$ python3 train.py --help
usage: train.py [-h] [--wandb_api_key WANDB_API_KEY] [-v VERBOSITY] [--input_dims H W] [--epochs NUM_EPOCHS] [--batch_size N] [--loss_fn LOSS_FUNCTION] [--learning_rate LR]
[--exp_track TRACK_EXPERIMENT] [--in_channels IN_C] [--out_channels OUT_C]
DATA_DIR CHECKPOINT_DIR RUN_NAME DATASET_NAME
Script to begin training and validation of UNet.
positional arguments:
DATA_DIR path to dataset directory
CHECKPOINT_DIR path to directory storing model checkpoints
RUN_NAME Name of current run
DATASET_NAME Name of dataset over which model is to be trained
optional arguments:
-h, --help show this help message and exit
--wandb_api_key WANDB_API_KEY
API key of your Weights and Biases Account.
-v VERBOSITY, --verbose VERBOSITY
setting verbosity to 1 will send email alerts to user after every epoch (default: 0)
Hyperparameters for model training:
--input_dims H W spatial dimensions of input image (default: [256, 256])
--epochs NUM_EPOCHS number of epochs to train (default: 10)
--batch_size N number of images per batch (default: 1)
--loss_fn LOSS_FUNCTION
Loss function for model training (default: BCELoss)
--learning_rate LR learning rate for training (default: 0.0001)
--exp_track TRACK_EXPERIMENT
'true' if you want to track experiments using wandb. Defaults to 'false'
Architecture parameters:
--in_channels IN_C number of channels in input image (default: 1)
--out_channels OUT_C number of classes in ground truth mask (default: 1)
Happy training! :)
The command below is an example call for training the model.
$ python3 train.py --verbose 0 --input_dims 256 256 --epochs 30 --batch_size 2 --loss_fn BCEWithLogitsLoss --learning_rate 1e-5 --exp_track false --in_channels 1 --out_channels 1 ./data ./models sample_run MURA
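Note that the example call swaps the default BCELoss for BCEWithLogitsLoss. The latter fuses the sigmoid and the binary cross-entropy into one numerically stable expression, max(x, 0) - x*y + log(1 + exp(-|x|)), so the model can output raw logits. A stdlib-only sketch of the identity (function names are mine, not PyTorch's):

```python
# Why BCEWithLogitsLoss: same value as sigmoid + BCELoss for moderate logits,
# but stays finite for extreme logits where the naive route breaks down.
import math

def bce_from_probability(p: float, y: float) -> float:
    """BCELoss on a sigmoid output p (what BCELoss expects as input)."""
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))

def bce_with_logits(x: float, y: float) -> float:
    """The numerically stable fused form, operating directly on the logit x."""
    return max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
```

For a logit of, say, -1000 the naive route would round sigmoid(x) to 0 and take log(0), while the fused form simply returns a finite loss.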
Read the TODO to see the current task list.
I have tried to adopt good coding practices as described in various blogs and articles. However, I feel there is still plenty of room to make the code more efficient, modular, and easy to understand.
I would be grateful if you could share your opinions by opening a GitHub Issue. Your criticisms are always welcome!
This project is licensed under the Apache License - see the LICENSE file for details.