Skip to content

Commit

Permalink
TransUNet model, scattering coefficients, and demo code commit
Browse files Browse the repository at this point in the history
  • Loading branch information
rupakl committed Apr 30, 2021
1 parent 1b1f73e commit 3bddc63
Show file tree
Hide file tree
Showing 10 changed files with 1,157 additions and 47 deletions.
321 changes: 321 additions & 0 deletions InferenceDemo.ipynb

Large diffs are not rendered by default.

137 changes: 137 additions & 0 deletions datasets/dataset_us_xray.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
import os
import random
import h5py
import numpy as np
import torch
from scipy import ndimage
from scipy.ndimage.interpolation import zoom
from torch.utils.data import Dataset
import cv2


def random_rot_flip(image, label):
    """Rotate both arrays by the same random multiple of 90 degrees, then flip
    both along the same randomly chosen axis. Returns copies."""
    num_quarter_turns = np.random.randint(0, 4)
    image, label = np.rot90(image, num_quarter_turns), np.rot90(label, num_quarter_turns)
    flip_axis = np.random.randint(0, 2)
    image = np.flip(image, axis=flip_axis).copy()
    label = np.flip(label, axis=flip_axis).copy()
    return image, label


def random_rotate(image, label):
    """Rotate image and label together by one random angle in [-20, 20) degrees.

    order=0 (nearest neighbour) so label class ids are not interpolated.
    """
    angle = np.random.randint(-20, 20)
    return tuple(ndimage.rotate(a, angle, order=0, reshape=False) for a in (image, label))


class RandomGenerator(object):
    """Training-time augmentation transform for a {'image', 'label'} sample:
    random flip/rotation, resize to ``output_size``, tensor conversion."""

    def __init__(self, output_size):
        # (height, width) the sample is resized to.
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # NOTE(review): random.random() is drawn twice, so the rotate branch
        # fires with probability 0.25 (not 0.5) — confirm this is intended.
        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)
        # Assumes a 2-D (H, W) image — TODO confirm callers never pass a channel axis.
        x, y = image.shape
        if x != self.output_size[0] or y != self.output_size[1]:
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3)  # bicubic for image
            label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)  # nearest keeps label ids
        image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)  # add channel axis
        label = torch.from_numpy(label.astype(np.float32))
        sample = {'image': image, 'label': label.long()}
        return sample

def test_transform(sample):
image, label = sample['image'], sample['label']
x, y = image.shape
output_size = [224, 224]
if x != output_size[0] or y != output_size[1]:
image = zoom(image, (output_size[0] / x, output_size[1] / y), order=3)
label = zoom(label, (output_size[0] / x, output_size[1] / y), order=0)
image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)
label = torch.from_numpy(label.astype(np.float32))
sample = {'image': image, 'label': label.long()}
return sample

def normalize_image(image, min_val=-125, max_val=275):
    """Linearly map [min_val, max_val] onto [0, 1], saturating values outside
    the range. Returns a new float array; the input is not modified."""
    scaled = (image - min_val) / (max_val - min_val)
    return np.clip(scaled, 0.0, 1.0)

class Ultrasound_dataset(Dataset):
    """Ultrasound segmentation dataset.

    Expects ``training_set/`` images and ``training_set_masks/`` masks under
    ``base_dir``, and a per-split sample list under ``list_dir``.

    Fixes: the split-list file handle was leaked (bare ``open`` never closed),
    and the train/eval branches of ``__getitem__`` duplicated the entire
    loading logic — now shared via ``_load_pair``.
    """

    def __init__(self, base_dir, list_dir, split, transform=None):
        self.transform = transform
        self.split = split
        files = {'train': 'ultrasound_train_list.txt', 'test_vol': 'ultrasound_val_list.txt'}
        # Context manager closes the list file instead of leaking the handle.
        with open(os.path.join(list_dir, files[self.split])) as list_file:
            self.sample_list = list_file.readlines()
        self.data_dir = base_dir

    def __len__(self):
        return len(self.sample_list)

    def _load_pair(self, slice_name):
        """Load one (image, label) pair: image normalized to [0, 1], mask
        binarized on the 255 value of its first channel."""
        img_path = os.path.join(self.data_dir, 'training_set', slice_name)
        label_path = os.path.join(self.data_dir, 'training_set_masks',
                                  os.path.splitext(slice_name)[0] + '_mask.png')
        image = cv2.imread(img_path)[:, :, 0]
        image = normalize_image(image, min_val=0, max_val=255)
        label = cv2.imread(label_path)
        label = (label[:, :, 0] == 255).astype(np.uint8)
        return image, label

    def __getitem__(self, idx):
        slice_name = self.sample_list[idx].strip('\n')
        image, label = self._load_pair(slice_name)
        if self.split != "train":
            # Evaluation samples carry an explicit leading channel axis.
            image = image[np.newaxis, :, :]
            label = label[np.newaxis, :, :]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = slice_name
        return sample

class LungXray_dataset(Dataset):
    """Chest X-ray lung-segmentation dataset backed by numpy arrays.

    Expects ``images_bcet/`` and ``masks_resized/`` under ``base_dir`` and a
    per-split sample list under ``list_dir``.

    Fixes: the split-list file handle was leaked (bare ``open`` never closed),
    and the train/eval branches of ``__getitem__`` duplicated the loading
    logic — now shared via ``_load_pair``.
    """

    def __init__(self, base_dir, list_dir, split, transform=None):
        self.transform = transform
        self.split = split
        files = {'train': 'covid_lungs_seg_train_list.txt',
                 'test_vol': 'covid_lungs_seg_val_list.txt',
                 'test': 'covid_lungs_seg_test_list.txt'}
        # Context manager closes the list file instead of leaking the handle.
        with open(os.path.join(list_dir, files[self.split])) as list_file:
            self.sample_list = list_file.readlines()
        self.data_dir = base_dir

    def __len__(self):
        return len(self.sample_list)

    def _load_pair(self, slice_name):
        """Load one (image, label) array pair for the given file name."""
        image = np.load(os.path.join(self.data_dir, 'images_bcet', slice_name))
        label = np.load(os.path.join(self.data_dir, 'masks_resized', slice_name))
        return image, label

    def __getitem__(self, idx):
        slice_name = self.sample_list[idx].strip('\n')
        image, label = self._load_pair(slice_name)
        if self.split != "train":
            # Evaluation samples carry an explicit leading channel axis.
            image = image[np.newaxis, :, :]
            label = label[np.newaxis, :, :]
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = slice_name
        return sample
139 changes: 139 additions & 0 deletions datasets/dataset_us_xray_scat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
import os
import random
import h5py
import numpy as np
import torch
from scipy import ndimage
from scipy.ndimage.interpolation import zoom
from torch.utils.data import Dataset
import cv2
import scipy.io


def random_rot_flip(image, label):
    """Apply an identical random 90-degree rotation and axis flip to the
    image and its label; both are returned as copies."""
    turns = np.random.randint(0, 4)
    rotated = [np.rot90(a, turns) for a in (image, label)]
    axis = np.random.randint(0, 2)
    flipped = [np.flip(a, axis=axis).copy() for a in rotated]
    return flipped[0], flipped[1]


def random_rotate(image, label):
    """Rotate both inputs by the same randomly drawn angle in [-20, 20) degrees,
    using nearest-neighbour interpolation so label ids stay discrete."""
    angle = np.random.randint(-20, 20)
    rotated_image = ndimage.rotate(image, angle, order=0, reshape=False)
    rotated_label = ndimage.rotate(label, angle, order=0, reshape=False)
    return rotated_image, rotated_label


class RandomGenerator(object):
    """Training-time augmentation transform for a {'image', 'label'} sample:
    random flip/rotation, resize to ``output_size``, tensor conversion."""

    def __init__(self, output_size):
        # (height, width) the sample is resized to.
        self.output_size = output_size

    def __call__(self, sample):
        image, label = sample['image'], sample['label']

        # NOTE(review): random.random() is drawn twice, so the rotate branch
        # fires with probability 0.25 (not 0.5) — confirm this is intended.
        if random.random() > 0.5:
            image, label = random_rot_flip(image, label)
        elif random.random() > 0.5:
            image, label = random_rotate(image, label)
        # Assumes a 2-D (H, W) image — TODO confirm callers never pass a channel axis.
        x, y = image.shape
        if x != self.output_size[0] or y != self.output_size[1]:
            image = zoom(image, (self.output_size[0] / x, self.output_size[1] / y), order=3)  # bicubic for image
            label = zoom(label, (self.output_size[0] / x, self.output_size[1] / y), order=0)  # nearest keeps label ids
        image = torch.from_numpy(image.astype(np.float32)).unsqueeze(0)  # add channel axis
        label = torch.from_numpy(label.astype(np.float32))
        sample = {'image': image, 'label': label.long()}
        return sample

def normalize_image(image, min_val=-125, max_val=275):
    """Map the range [min_val, max_val] linearly onto [0, 1]; values outside
    the range saturate at the corresponding endpoint."""
    scaled = (image - min_val) / (max_val - min_val)
    return np.minimum(np.maximum(scaled, 0), 1)

class Ultrasound_dataset(Dataset):
    """Ultrasound segmentation dataset that additionally returns scattering
    coefficients loaded from per-image ``.mat`` files (key ``'S'``).

    Fixes: the split-list file handle was leaked (bare ``open`` never closed);
    ``list(set(...))`` made the index->sample mapping depend on hash
    randomization (non-reproducible across runs) — replaced with ``sorted``;
    the duplicated train/eval loading branches are shared via ``_load_pair``.
    """

    # Samples excluded from the split lists.
    _EXCLUDED_FILES = frozenset({'186_HC.png', '346_HC.png', '628_2HC.png'})

    def __init__(self, base_dir, list_dir, split, transform=None):
        self.transform = transform
        self.split = split
        files = {'train': 'ultrasound_train_list.txt', 'test_vol': 'ultrasound_val_list.txt'}
        # Context manager closes the list file instead of leaking the handle.
        with open(os.path.join(list_dir, files[self.split])) as list_file:
            names = [line.strip() for line in list_file]
        # sorted() gives a deterministic sample order across runs.
        self.sample_list = sorted(set(names) - self._EXCLUDED_FILES)
        self.data_dir = base_dir
        self.scat_mat_dir = os.path.join(self.data_dir, 'mat_arrs')

    def __len__(self):
        return len(self.sample_list)

    def _load_pair(self, slice_name):
        """Load one (image, label) pair: image normalized to [0, 1], mask
        binarized on the 255 value of its first channel."""
        img_path = os.path.join(self.data_dir, 'training_set', slice_name)
        label_path = os.path.join(self.data_dir, 'training_set_masks',
                                  os.path.splitext(slice_name)[0] + '_mask.png')
        image = cv2.imread(img_path)[:, :, 0]
        image = normalize_image(image, min_val=0, max_val=255)
        label = cv2.imread(label_path)
        label = (label[:, :, 0] == 255).astype(np.uint8)
        return image, label

    def __getitem__(self, idx):
        slice_name = self.sample_list[idx].strip('\n')
        image, label = self._load_pair(slice_name)
        if self.split != "train":
            # Evaluation samples carry an explicit leading channel axis.
            image = image[np.newaxis, :, :]
            label = label[np.newaxis, :, :]
        scat_mat_path = os.path.join(self.scat_mat_dir,
                                     os.path.splitext(slice_name)[0] + '.mat')
        scat_mat = scipy.io.loadmat(scat_mat_path)
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = slice_name
        sample['scat_mat'] = torch.from_numpy(scat_mat['S'])
        return sample

class LungXray_dataset(Dataset):
    """Chest X-ray lung-segmentation dataset (numpy-array backed) that also
    returns scattering coefficients from per-image ``.mat`` files (key ``'S'``).

    Fixes: the split-list file handle was leaked (bare ``open`` never closed),
    and the train/eval branches of ``__getitem__`` duplicated the loading
    logic — now shared via ``_load_pair``.
    """

    def __init__(self, base_dir, list_dir, split, transform=None):
        self.transform = transform
        self.split = split
        files = {'train': 'covid_lungs_seg_train_list.txt',
                 'test_vol': 'covid_lungs_seg_val_list.txt',
                 'test': 'covid_lungs_seg_test_list.txt'}
        # Context manager closes the list file instead of leaking the handle.
        with open(os.path.join(list_dir, files[self.split])) as list_file:
            self.sample_list = list_file.readlines()
        self.data_dir = base_dir
        self.scat_mat_dir = os.path.join(self.data_dir, 'mat_arrs')

    def __len__(self):
        return len(self.sample_list)

    def _load_pair(self, slice_name):
        """Load one (image, label) array pair for the given file name."""
        image = np.load(os.path.join(self.data_dir, 'images_bcet', slice_name))
        label = np.load(os.path.join(self.data_dir, 'masks_resized', slice_name))
        return image, label

    def __getitem__(self, idx):
        slice_name = self.sample_list[idx].strip('\n')
        image, label = self._load_pair(slice_name)
        if self.split != "train":
            # Evaluation samples carry an explicit leading channel axis.
            image = image[np.newaxis, :, :]
            label = label[np.newaxis, :, :]
        scat_mat_path = os.path.join(self.scat_mat_dir,
                                     os.path.splitext(slice_name)[0] + '.mat')
        scat_mat = scipy.io.loadmat(scat_mat_path)
        sample = {'image': image, 'label': label}
        if self.transform:
            sample = self.transform(sample)
        sample['case_name'] = slice_name
        sample['scat_mat'] = torch.from_numpy(scat_mat['S'])
        return sample
35 changes: 8 additions & 27 deletions eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from scipy.ndimage import zoom
from datasets.dataset_us_xray import Ultrasound_dataset, LungXray_dataset
from utils import calculate_metric_percase
from networks.vit_seg_modeling import VisionTransformer as ViT_seg
from networks.vit_seg_modeling import CONFIGS as CONFIGS_ViT_seg
from networks.TransUNet_model import TransUNet

cudnn.benchmark = False
cudnn.deterministic = True
Expand All @@ -26,22 +25,13 @@
parser.add_argument('--root_path', type=str,
default='', help='root dir for validation volume data')
parser.add_argument('--dataset', type=str,
default='Ultrasound', help='experiment_name')
default='Ultrasound', help='name of dataset for training')
parser.add_argument('--num_classes', type=int,
default=4, help='output channel of network')
parser.add_argument('--list_dir', type=str,
default='./lists/lists_Ultrasound', help='list dir')
default=2, help='number of classes including background')
parser.add_argument('--list_dir', type=str, default='./lists/lists_Ultrasound', help='path to dir where train-val split is stored')
parser.add_argument('--model_path', type=str,
required=True, help='path to trained model')
parser.add_argument('--img_size', type=int, default=224, help='input patch size of network input')

parser.add_argument('--n_skip', type=int, default=3, help='using number of skip-connect, default is num')
parser.add_argument('--vit_name', type=str, default='ViT-B_16', help='select one vit model')

parser.add_argument('--test_save_dir', type=str, default='../predictions', help='saving prediction as nii!')
parser.add_argument('--deterministic', type=int, default=1, help='whether use deterministic training')
parser.add_argument('--base_lr', type=float, default=0.01, help='segmentation network learning rate')
parser.add_argument('--vit_patches_size', type=int, default=16, help='vit_patches_size, default is 16')
args = parser.parse_args()

def test_single_volume(image, label, net, classes, patch_size=[256, 256], case=None):
Expand All @@ -52,7 +42,7 @@ def test_single_volume(image, label, net, classes, patch_size=[256, 256], case=N
slice = image[ind, :, :]
x, y = slice.shape[0], slice.shape[1]
if x != patch_size[0] or y != patch_size[1]:
slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3) # previous using 0
slice = zoom(slice, (patch_size[0] / x, patch_size[1] / y), order=3)
input = torch.from_numpy(slice).unsqueeze(0).unsqueeze(0).float().cuda()
net.eval()
with torch.no_grad():
Expand Down Expand Up @@ -108,7 +98,7 @@ def inference(args, model):
'list_dir': './lists/lists_Ultrasound',
'num_classes': 2,
},
'CovidLungSeg': {
'LungSeg': {
'Dataset': LungXray_dataset,
'root_path': '/ssd_scratch/cvit/rupraze/data/lungs_seg_dataset',
'list_dir': './lists/lists_CovidLungSeg',
Expand All @@ -120,19 +110,10 @@ def inference(args, model):
args.root_path = dataset_config[dataset_name]['root_path']
args.Dataset = dataset_config[dataset_name]['Dataset']
args.list_dir = dataset_config[dataset_name]['list_dir']
args.z_spacing = dataset_config[dataset_name]['z_spacing']
args.is_pretrain = True

# name the same snapshot defined in train script!
args.exp = 'TU_' + dataset_name + str(args.img_size)

config_vit = CONFIGS_ViT_seg[args.vit_name]
config_vit.n_classes = args.num_classes
config_vit.n_skip = args.n_skip
config_vit.patches.size = (args.vit_patches_size, args.vit_patches_size)
config_vit.patches.grid = (int(args.img_size/args.vit_patches_size), int(args.img_size/args.vit_patches_size))
net = ViT_seg(config_vit, img_size=args.img_size, num_classes=config_vit.n_classes).cuda()
args.exp = dataset_name + str(args.img_size)

net = TransUNet(num_classes=args.num_classes).cuda()
if not os.path.exists(args.model_path):
print("Model path doesn't exist \nExiting eval process.")
exit()
Expand Down
Loading

0 comments on commit 3bddc63

Please sign in to comment.