
UNet Paper Implementation using PyTorch


Table of Contents

  • About
  • Dataset
  • Getting Started
  • TODO
  • A Kind Request
  • License

About

A personal project to learn about semantic segmentation using PyTorch.

Implementation of the UNet architecture, as described in Ronneberger et al., for the task of semantic segmentation of the humerus bone, with X-ray images as the input modality.

Reproduction of results is carried out using a subset of the MURA dataset.

The main code is located in the train.py file. All other code files are imported into train.py for training and testing the model.

The code to perform segmentation on custom images is in the predict.py file.
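As a rough illustration, the sketch below shows the kind of single-image inference predict.py performs. It is only a sketch: the UNet class name, checkpoint path, and preprocessing are assumptions, not the repo's actual API.

import torch
from PIL import Image
from torchvision import transforms

from model_builder import UNet  # assumed class name; check model_builder.py

device = "cuda" if torch.cuda.is_available() else "cpu"

# load a trained checkpoint (path is illustrative)
model = UNet(in_channels=1, out_channels=1).to(device)
model.load_state_dict(torch.load("models/sample_run.pth", map_location=device))
model.eval()

preprocess = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),  # float tensor in [0, 1], shape (1, H, W)
])

image = preprocess(Image.open("custom_xray.png")).unsqueeze(0).to(device)

with torch.inference_mode():
    logits = model(image)
    mask = (torch.sigmoid(logits) > 0.5).squeeze().cpu().numpy()  # binary mask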

For your reference, the UNet architecture diagram (from Ronneberger et al.) is attached below.

[Figure: UNet architecture diagram, reproduced from Ronneberger et al.]
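The core building block of the architecture is a pair of 3x3 convolutions, each followed by a ReLU. A minimal sketch of this block is shown below; it uses padded convolutions for simplicity, whereas the original paper uses unpadded ones, and the repo's model_builder.py may make different choices.

import torch
from torch import nn

class DoubleConv(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU, as in Ronneberger et al."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        self.block = nn.Sequential(
            # padding=1 preserves spatial dims; the paper uses unpadded convs
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.block(x)

# The encoder repeats this block with 2x2 max pooling between stages; the
# decoder upsamples with transposed convolutions and concatenates the
# matching encoder feature map (the "skip connection") before each block.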

Dataset

The MURA dataset provides the X-ray images but not the corresponding ground truth segmentation masks.

Hence, ground truth annotations were created using the LabelMe software. The created masks were later validated by medical professionals.
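For readers unfamiliar with LabelMe, it exports polygon annotations as JSON. The sketch below shows one way such a file could be rasterized into a binary PNG mask, assuming the standard LabelMe export format; the file names are illustrative, and this is not necessarily the script used for this dataset.

import json

from PIL import Image, ImageDraw

with open("annotations/humerus_001.json") as f:  # illustrative file name
    ann = json.load(f)

# start from an all-background (0) mask of the original image size
mask = Image.new("L", (ann["imageWidth"], ann["imageHeight"]), 0)
draw = ImageDraw.Draw(mask)

for shape in ann["shapes"]:
    # each shape stores its polygon vertices as [[x1, y1], [x2, y2], ...]
    polygon = [tuple(point) for point in shape["points"]]
    draw.polygon(polygon, fill=255)  # foreground pixels set to 255

mask.save("data/train/masks/humerus_001.png")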

Getting Started

These instructions will get you a copy of the project up and running on your local machine.

Project Structure

The project is structured as follows:

UNet-PyTorch/
├── config/
│   └── __init__.py
├── data/
│   ├── train/
│   │   ├── images/
│   │   └── masks/
│   └── test/
│       ├── images/
│       └── masks/
├── models/
├── requirements.txt
├── utils.py
├── data_setup.py
├── model_builder.py
├── engine.py
├── train.py
└── predict.py     

Ensure that your directory structure matches the one shown above. In particular, make sure your data folder follows this layout. For reference, an empty data directory with this structure is included in the project.
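To make the expected layout concrete, here is a minimal sketch of a paired image/mask Dataset that reads from it. The class name and transforms are assumptions; the repo's actual data loading lives in data_setup.py and may differ.

from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class SegmentationDataset(Dataset):
    """Pairs data/<split>/images/* with data/<split>/masks/* by sorted order."""

    def __init__(self, root: str, split: str = "train", size=(256, 256)):
        self.image_paths = sorted((Path(root) / split / "images").glob("*"))
        self.mask_paths = sorted((Path(root) / split / "masks").glob("*"))
        assert len(self.image_paths) == len(self.mask_paths)
        self.to_tensor = transforms.Compose([
            transforms.Grayscale(num_output_channels=1),
            transforms.Resize(size),
            transforms.ToTensor(),
        ])

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx):
        image = self.to_tensor(Image.open(self.image_paths[idx]))
        # re-binarize the mask, since resizing can introduce gray values
        mask = (self.to_tensor(Image.open(self.mask_paths[idx])) > 0.5).float()
        return image, mask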

Prerequisites

You need a machine with Python 3.6+ and a Unix shell (e.g. zsh) installed:

$ python3.8 -V
Python 3.8.18

$ echo $SHELL
/bin/zsh

Installing the Requirements

Clone the repository:

$ git clone https://github.com/ajaystar8/UNet-PyTorch.git

Install the requirements in a new conda environment:

$ conda create -n name_of_env python=3.8 --file requirements.txt

Running the Code

Navigate to the config package and specify the following (a hypothetical sketch follows this list):

  • Path to your data directory.
  • Path to the models directory for saving model checkpoints.
  • Any other hyperparameters you want to change.
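The variable names below are hypothetical, intended only to illustrate what config/__init__.py might hold; check the actual file for the real names.

# config/__init__.py -- hypothetical sketch; actual names may differ
DATA_DIR = "data"          # root of the train/test image and mask folders
CHECKPOINT_DIR = "models"  # where model checkpoints are written

# default hyperparameters (overridable via train.py's command-line flags)
INPUT_DIMS = (256, 256)
NUM_EPOCHS = 10
BATCH_SIZE = 1
LEARNING_RATE = 1e-4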

Activate the conda environment:

$ conda activate name_of_env

To start training the model, call the train.py script. Effort has been made to ensure that most of the parameters and hyperparameters used to train and test the model can be set manually. You can list the available command-line arguments by executing:

$ python3 train.py --help

usage: train.py [-h] [--wandb_api_key WANDB_API_KEY] [-v VERBOSITY] [--input_dims H W] [--epochs NUM_EPOCHS] [--batch_size N] [--loss_fn LOSS_FUNCTION] [--learning_rate LR]
                [--exp_track TRACK_EXPERIMENT] [--in_channels IN_C] [--out_channels OUT_C]
                DATA_DIR CHECKPOINT_DIR RUN_NAME DATASET_NAME

Script to begin training and validation of UNet.

positional arguments:
  DATA_DIR              path to dataset directory
  CHECKPOINT_DIR        path to directory storing model checkpoints
  RUN_NAME              Name of current run
  DATASET_NAME          Name of dataset over which model is to be trained

optional arguments:
  -h, --help            show this help message and exit
  --wandb_api_key WANDB_API_KEY
                        API key of your Weights and Biases Account.
  -v VERBOSITY, --verbose VERBOSITY
                        setting verbosity to 1 will send email alerts to user after every epoch (default: 0)

Hyperparameters for model training:
  --input_dims H W      spatial dimensions of input image (default: [256, 256])
  --epochs NUM_EPOCHS   number of epochs to train (default: 10)
  --batch_size N        number of images per batch (default: 1)
  --loss_fn LOSS_FUNCTION
                        Loss function for model training (default: BCELoss)
  --learning_rate LR    learning rate for training (default: 0.0001)
  --exp_track TRACK_EXPERIMENT
                        'true' if you want to track experiments using wandb. Defaults to 'false'

Architecture parameters:
  --in_channels IN_C    number of channels in input image (default: 1)
  --out_channels OUT_C  number of classes in ground truth mask (default: 1)

Happy training! :)

The command shown below is an example invocation for training the model; note that the four positional arguments come after the optional flags.

$ python3 train.py --verbose 0 --input_dims 256 256 --epochs 30 --batch_size 2 --loss_fn BCEWithLogitsLoss --learning_rate 1e-5 --exp_track false --in_channels 1 --out_channels 1 ./data ./models sample_run MURA 
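The example above selects BCEWithLogitsLoss rather than the default BCELoss. The sketch below, assuming a single-channel binary mask, shows why that distinction matters in a training step; the repo's actual loop lives in engine.py and may be structured differently.

import torch
from torch import nn

model = nn.Conv2d(1, 1, kernel_size=3, padding=1)  # stand-in for the UNet

# BCEWithLogitsLoss applies the sigmoid internally, so the model should emit
# raw logits; plain BCELoss instead expects probabilities already in [0, 1]
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

images = torch.rand(2, 1, 256, 256)                 # batch of grayscale images
masks = (torch.rand(2, 1, 256, 256) > 0.5).float()  # binary ground truth

model.train()
loss = loss_fn(model(images), masks)

optimizer.zero_grad()
loss.backward()
optimizer.step()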

TODO

Read the TODO to see the current task list.

A Kind Request

I have tried to adopt good coding practices as described in various blogs and articles. However, I feel there is still a lot of room for improvement in making the code more efficient, modular, and easy to understand.

I would be thankful if you could share your opinions by opening a GitHub Issue. Your criticisms are always welcome!

License

This project is licensed under the Apache License - see the LICENSE file for details.
