diff --git a/README.txt b/README.txt new file mode 100644 index 0000000..e499be0 --- /dev/null +++ b/README.txt @@ -0,0 +1,2 @@ +将用于训练的四张大图(382、182以及他们的标签)放入data文件夹中,然后运行run.sh。 +注:run.sh文件包含切图、划分训练集、验证集(python cut_data.py)和训练模型(CUDA_VISIBLE_DEVICES=0 python train.py --backbone=hrnet --batchsize=4 --lr=0.01 --num_epochs=150) diff --git a/__pycache__/evaluate.cpython-36.pyc b/__pycache__/evaluate.cpython-36.pyc new file mode 100644 index 0000000..3785313 Binary files /dev/null and b/__pycache__/evaluate.cpython-36.pyc differ diff --git a/__pycache__/loss.cpython-36.pyc b/__pycache__/loss.cpython-36.pyc new file mode 100644 index 0000000..25c4088 Binary files /dev/null and b/__pycache__/loss.cpython-36.pyc differ diff --git a/cut_data.py b/cut_data.py new file mode 100644 index 0000000..edcc52b --- /dev/null +++ b/cut_data.py @@ -0,0 +1,103 @@ +import os +import numpy as np +from PIL import Image +import cv2 as cv +from tqdm import tqdm +import random +import shutil +Image.MAX_IMAGE_PIXELS = 1000000000000000 +TARGET_W, TARGET_H = 1024, 1024 + + +def cut_images(image_name, image_path, label_path, save_dir, is_show=True): + # 初始化路径 + image_save_dir = os.path.join(save_dir, "images/"+image_name.split(".")[0]) + if not os.path.exists(image_save_dir): os.makedirs(image_save_dir) + label_save_dir = os.path.join(save_dir, "labels/"+image_name.split(".")[0]) + if not os.path.exists(label_save_dir): os.makedirs(label_save_dir) + if is_show: + label_show_save_dir = os.path.join(save_dir, "labels_show/"+image_name.split(".")[0]) + if not os.path.exists(label_show_save_dir): os.makedirs(label_show_save_dir) + + target_w, target_h = TARGET_W, TARGET_H + overlap = target_h // 8 # 128 + stride = target_h - overlap # 896 + + image = np.asarray(Image.open(image_path)) + label = np.asarray(Image.open(label_path)) + image = cv.cvtColor(image,cv.COLOR_RGB2BGR) + + h, w = image.shape[0], image.shape[1] + print("原始大小: ", w, h) + if (w-target_w) % stride: + new_w = ((w-target_w)//stride + 1)*stride + target_w + if (h-target_h) % stride: + new_h = ((h-target_h)//stride + 1)*stride + target_h + image = cv.copyMakeBorder(image,0,new_h-h,0,new_w-w,cv.BORDER_CONSTANT,value=[0,0,0]) + label = cv.copyMakeBorder(label,0,new_h-h,0,new_w-w,cv.BORDER_CONSTANT,value=1) + h, w = image.shape[0], image.shape[1] + print("填充至整数倍: ", w, h) + + def crop(cnt, crop_image, crop_label, is_show=is_show): + _name = image_name.split(".")[0] + image_save_path = os.path.join(image_save_dir, _name+"_"+str(cnt[0])+"_"+str(cnt[1])+".png") + label_save_path = os.path.join(label_save_dir, _name+"_"+str(cnt[0])+"_"+str(cnt[1])+".png") + label_show_save_path = os.path.join(label_show_save_dir, _name+"_"+str(cnt[0])+"_"+str(cnt[1])+".png") + cv.imwrite(image_save_path, crop_image) + cv.imwrite(label_save_path, crop_label) + if is_show: + cv.imwrite(label_show_save_path, crop_label*255) + + h, w = image.shape[0], image.shape[1] + cnt = 0 + for i in tqdm(range((w-target_w)//stride + 1)): + for j in range((h-target_h)//stride + 1): + topleft_x = i*stride + topleft_y = j*stride + crop_image = image[topleft_y:topleft_y+target_h,topleft_x:topleft_x+target_w] + crop_label = label[topleft_y:topleft_y+target_h,topleft_x:topleft_x+target_w] + if np.sum(crop_image) != 0: + crop((i, j),crop_image,crop_label) + cnt += 1 + print(cnt) + # os.remove(image_path) + + +def get_train_val(): + file_train = open('./data/train.txt', 'w') + file_val = open('./data/val.txt', 'w') + + image_list_382 = os.listdir('./data/images/382') + image_list_182 = 
os.listdir('./data/images/182') + + print(len(image_list_182)) + print(len(image_list_382)) + + random.shuffle(image_list_382) + random.shuffle(image_list_182) + + for ele in image_list_382: + if random.randint(0, 10) < 2: # 8:2 + file_val.write(str(ele) + '\n') + else: + file_train.write(str(ele) + '\n') + + for ele in image_list_182: + if random.randint(0, 10) < 2: # 8:2 + file_val.write(str(ele) + '\n') + else: + file_train.write(str(ele) + '\n') + + file_train.close() + file_val.close() + + +if __name__ == "__main__": + data_dir = "./data" + img_name1 = "382.png" + img_name2 = "182.png" + label_name1 = "382_label.png" + label_name2 = "182_label.png" + cut_images(img_name1, os.path.join(data_dir, img_name1), os.path.join(data_dir, label_name1), data_dir) + cut_images(img_name2, os.path.join(data_dir, img_name2), os.path.join(data_dir, label_name2), data_dir) + get_train_val() diff --git a/data/edge_utils.py b/data/edge_utils.py new file mode 100644 index 0000000..ebbbe2d --- /dev/null +++ b/data/edge_utils.py @@ -0,0 +1,84 @@ +import cv2 +import os +import numpy as np +import torch +import torch.nn as nn +from tqdm import tqdm +from scipy.ndimage.morphology import distance_transform_edt + + +def onehot_to_multiclass_edges(mask, radius, num_classes): + """ + Converts a segmentation mask (K,H,W) to an edgemap (K,H,W) + + """ + if radius < 0: + return mask + + # We need to pad the borders for boundary conditions + mask_pad = np.pad(mask, ((0, 0), (1, 1), (1, 1)), mode='constant', constant_values=0) + + channels = [] + for i in range(num_classes): + dist = distance_transform_edt(mask_pad[i, :]) + distance_transform_edt(1.0 - mask_pad[i, :]) + dist = dist[1:-1, 1:-1] + dist[dist > radius] = 0 + dist = (dist > 0).astype(np.uint8) + channels.append(dist) + + return np.array(channels) + + +def onehot_to_binary_edges(mask, radius, num_classes): + """ + Converts a segmentation mask (K,H,W) to a binary edgemap (H,W) + + """ + + if radius < 0: + return mask + + # We need to pad the borders for boundary conditions + mask_pad = np.pad(mask, ((0, 0), (1, 1), (1, 1)), mode='constant', constant_values=0) + + edgemap = np.zeros(mask.shape[1:]) + for i in range(num_classes): + # ti qu lun kuo + dist = distance_transform_edt(mask_pad[i, :]) + distance_transform_edt(1.0 - mask_pad[i, :]) + dist = dist[1:-1, 1:-1] + dist[dist > radius] = 0 + edgemap += dist + # edgemap = np.expand_dims(edgemap, axis=0) + edgemap = (edgemap > 0).astype(np.uint8)*255 + return edgemap + + +def mask_to_onehot(mask, num_classes): + """ + Converts a segmentation mask (H,W) to (K,H,W) where the last dim is a one + hot encoding vector + + """ + _mask = [mask == (i) for i in range(num_classes)] + return np.array(_mask).astype(np.uint8) + +if __name__ == '__main__': + label = cv2.imread('/media/ws/新加卷1/wy/dataset/HUAWEI/data/labels/182/182_16_23.png',0) + img = cv2.imread('/media/ws/新加卷1/wy/dataset/HUAWEI/data/images/182/182_16_23.png') + oneHot_label = mask_to_onehot(label, 2) + edge = onehot_to_binary_edges(oneHot_label, 2, 2) # #edge=255,background=0 + edge[:2, :] = 0 + edge[-2:, :] = 0 + edge[:, :2] = 0 + edge[:, -2:] = 0 + print(edge) + print(np.unique(edge)) + print(edge.shape) + cv2.imwrite('test.png',edge) + cv2.namedWindow('1',0) + cv2.namedWindow('2',0) + cv2.namedWindow('3',0) + cv2.imshow('1',label*255) + cv2.imshow('2',edge) + cv2.imshow('3',img) + cv2.waitKey() \ No newline at end of file diff --git a/data/make_data.py b/data/make_data.py new file mode 100644 index 0000000..e45ce52 --- /dev/null +++ b/data/make_data.py 
@@ -0,0 +1,239 @@ +import cv2 +import pdb +import collections +import matplotlib.pyplot as plt +import numpy as np +import os +import os.path as osp +from PIL import Image, ImageOps, ImageFilter +import random +import torch +import torchvision +from torch.utils import data +import torchvision.transforms as transforms +from .edge_utils import * +class GaofenTrain(data.Dataset): + def __init__(self, root, list_path, crop_size=(640, 640), + scale=True, mirror=True,rotation=True, bright=False, ignore_label=1, use_aug=True, network='resnet101'): + self.root = root + self.src_h = 1024 + self.src_w = 1024 + self.list_path = list_path + self.crop_h, self.crop_w = crop_size + self.bright = bright + self.scale = scale + self.ignore_label = ignore_label + self.is_mirror = mirror + self.rotation = rotation + self.use_aug = use_aug + self.img_ids = [i_id.strip() for i_id in open(list_path)] + self.files = [] + self.network = network + for item in self.img_ids: + image_path = 'images/'+item.split('_')[0]+'/'+item + label_path = 'labels/'+item.split('_')[0]+'/'+item + name = item + img_file = osp.join(self.root, image_path) + label_file = osp.join(self.root, label_path) + self.files.append({ + "img": img_file, + "label": label_file, + "name": name, + "weight": 1 + }) + + print('{} images are loaded!'.format(len(self.img_ids))) + + def __len__(self): + return len(self.files) + + def random_brightness(self, img): + if random.random() < 0.5: + return img + self.shift_value = 10 #取自HRNet + img = img.astype(np.float32) + shift = random.randint(-self.shift_value, self.shift_value) + img[:, :, :] += shift + img = np.around(img) + img = np.clip(img, 0, 255).astype(np.uint8) + return img + + def generate_scale_label(self, image, label): + f_scale = 0.5 + random.randint(0, 11) / 10.0 # [0.5, 1.5] + image = cv2.resize(image, None, fx=f_scale, fy=f_scale, interpolation=cv2.INTER_LINEAR) + label = cv2.resize(label, None, fx=f_scale, fy=f_scale, interpolation=cv2.INTER_NEAREST) + return image, label + + def __getitem__(self, index): + datafiles = self.files[index] + image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR) + label = cv2.imread(datafiles["label"],0) + #旋转90/180/270 + if self.rotation and random.random() > 0.5: + angel = np.random.randint(1,4) + M = cv2.getRotationMatrix2D(((self.src_h - 1) / 2., (self.src_w - 1) / 2.), 90*angel, 1) + image = cv2.warpAffine(image, M, (self.src_h, self.src_w), flags=cv2.INTER_LINEAR) + label = cv2.warpAffine(label, M, (self.src_h, self.src_w), flags=cv2.INTER_NEAREST, borderValue=self.ignore_label) + # 旋转-30-30 + if self.rotation and random.random() > 0.5: + angel = np.random.randint(-30,30) + M = cv2.getRotationMatrix2D(((self.src_h - 1) / 2., (self.src_w - 1) / 2.), angel, 1) + image = cv2.warpAffine(image, M, (self.src_h, self.src_w), flags=cv2.INTER_LINEAR) + label = cv2.warpAffine(label, M, (self.src_h, self.src_w), flags=cv2.INTER_NEAREST, borderValue=self.ignore_label) + size = image.shape + if self.scale: #尺度变化 + image, label = self.generate_scale_label(image, label) + if self.bright: #亮度变化 + image = self.random_brightness(image) + image = np.asarray(image, np.float32) + image = image[:, :, ::-1] + mean = (0.355403, 0.383969, 0.359276) + std = (0.206617, 0.202157, 0.210082) + image /= 255. 
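+        # Normalize: scale pixels to [0, 1], then standardize with the per-channel
+        # mean/std defined above (presumably statistics of the 382/182 training tiles).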
+ image -= mean + image /= std + + img_h, img_w = label.shape + pad_h = max(self.crop_h - img_h, 0) + pad_w = max(self.crop_w - img_w, 0) + if pad_h > 0 or pad_w > 0: + img_pad = cv2.copyMakeBorder(image, 0, pad_h, 0, + pad_w, cv2.BORDER_CONSTANT, + value=(0.0, 0.0, 0.0)) + label_pad = cv2.copyMakeBorder(label, 0, pad_h, 0, + pad_w, cv2.BORDER_CONSTANT, + value=(self.ignore_label,)) #边界填充的是ignore + else: + img_pad, label_pad = image, label + + img_h, img_w = label_pad.shape + h_off = random.randint(0, img_h - self.crop_h) + w_off = random.randint(0, img_w - self.crop_w) + image = np.asarray(img_pad[h_off: h_off + self.crop_h, w_off: w_off + self.crop_w], np.float32) + label = np.asarray(label_pad[h_off: h_off + self.crop_h, w_off: w_off + self.crop_w], np.float32) + image = image.transpose((2, 0, 1)) # 3XHXW + + if self.is_mirror: #水平/垂直翻转 + flip1 = np.random.choice(2) * 2 - 1 + image = image[:, :, ::flip1] + label = label[:, ::flip1] + flip2 = np.random.choice(2) * 2 - 1 + image = image[:,::flip2, :] + label = label[::flip2,:] + oneHot_label = mask_to_onehot(label,2) #edge=255,background=0 + edge = onehot_to_binary_edges(oneHot_label,2,2) + # 消去图像边缘 + edge[:2, :] = 0 + edge[-2:, :] = 0 + edge[:, :2] = 0 + edge[:, -2:] = 0 + return image.copy(), label.copy(), edge,np.array(size), datafiles + +class GaofenVal(data.Dataset): + def __init__(self, root, list_path, max_iters=None, crop_size=(321, 321), + scale=False, mirror=False, ignore_label=255, use_aug=True, network="renset101"): + self.root = root + self.list_path = list_path + self.crop_h, self.crop_w = crop_size + self.scale = scale + self.ignore_label = ignore_label + self.is_mirror = mirror + self.use_aug = use_aug + self.img_ids = [i_id.strip() for i_id in open(list_path)] + self.files = [] + self.network = network + for item in self.img_ids: + image_path = 'images/'+item.split('_')[0]+'/'+item + label_path = 'labels/'+item.split('_')[0]+'/'+item + name = item + img_file = osp.join(self.root, image_path) + label_file = osp.join(self.root, label_path) + self.files.append({ + "img": img_file, + "label": label_file, + "name": name, + "weight": 1 + }) + self.id_to_trainid = {} + + print('{} images are loaded!'.format(len(self.img_ids))) + + def __len__(self): + return len(self.files) + + + def __getitem__(self, index): + datafiles = self.files[index] + image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR) + label = cv2.imread(datafiles["label"],0) + + size = image.shape + + image = np.asarray(image, np.float32) + image = image[:, :, ::-1] + mean = (0.355403, 0.383969, 0.359276) + std = (0.206617, 0.202157, 0.210082) + image /= 255. 
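+        # Same mean/std normalization as GaofenTrain, but no geometric or
+        # photometric augmentation is applied at validation time.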
+ image -= mean + image /= std + + img_h, img_w = label.shape + + image = np.asarray(image, np.float32) + label = np.asarray(label, np.float32) + image = image.transpose((2, 0, 1)) # 3XHXW + return image.copy(), label.copy(),label.copy(), np.array(size), datafiles + + + +class GaofenSubmit(data.Dataset): + def __init__(self, root, list_path, max_iters=None, crop_size=(321, 321), + scale=False, mirror=False, ignore_label=255, use_aug=True, network="renset101"): + self.root = root + self.list_path = list_path + self.crop_h, self.crop_w = crop_size + self.scale = scale + self.ignore_label = ignore_label + self.is_mirror = mirror + self.use_aug = use_aug + self.img_ids = [i_id.strip() for i_id in open(list_path)] + self.files = [] + self.network = network + for item in self.img_ids: + image_path = 'images/'+item + label_path = 'labels/'+item[:-4]+'.png' + name = item + img_file = osp.join(self.root, image_path) + label_file = osp.join(self.root, label_path) + self.files.append({ + "img": img_file, + "label": label_file, + "name": name, + "weight": 1 + }) + self.id_to_trainid = {} + + print('{} images are loaded!'.format(len(self.img_ids))) + + def __len__(self): + return len(self.files) + + + def __getitem__(self, index): + datafiles = self.files[index] + image = cv2.imread(datafiles["img"], cv2.IMREAD_COLOR) + size = image.shape + name = datafiles["name"] + + image = np.asarray(image, np.float32) + image = image[:, :, ::-1] + mean = (0.355403, 0.383969, 0.359276) + std = (0.206617, 0.202157, 0.210082) + image /= 255. + image -= mean + image /= std + + image = np.asarray(image, np.float32) + image = image.transpose((2, 0, 1)) # 3XHXW + return image.copy(), np.array(size), name + diff --git a/evaluate.py b/evaluate.py new file mode 100644 index 0000000..9cfef8c --- /dev/null +++ b/evaluate.py @@ -0,0 +1,53 @@ +import numpy as np + + +class Evaluator(object): + def __init__(self, num_class): + self.num_class = num_class + self.confusion_matrix = np.zeros((self.num_class,)*2) + + def Pixel_Accuracy(self): + Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum() + return Acc + + def Pixel_Accuracy_Class(self): + # print(self.confusion_matrix.sum(axis=1)) + # print(self.confusion_matrix.sum()) + Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) + Acc = np.nanmean(Acc) + return Acc + + def Mean_Intersection_over_Union(self): + MIoU = np.diag(self.confusion_matrix) / ( + np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - + np.diag(self.confusion_matrix)) + MIoU = np.nanmean(MIoU) + return MIoU + def Mean_Intersection_over_Union_test(self): + MIoU = np.diag(self.confusion_matrix) / ( + np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - + np.diag(self.confusion_matrix)) + return MIoU + + def Frequency_Weighted_Intersection_over_Union(self): + freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) + iu = np.diag(self.confusion_matrix) / ( + np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - + np.diag(self.confusion_matrix)) + + FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() + return FWIoU + + def _generate_matrix(self, gt_image, pre_image): + mask = (gt_image >= 0) & (gt_image < self.num_class) + label = self.num_class * gt_image[mask].astype('int') + pre_image[mask] + count = np.bincount(label, minlength=self.num_class**2) + confusion_matrix = count.reshape(self.num_class, self.num_class) + return confusion_matrix + + def add_batch(self, gt_image, 
pre_image): + assert gt_image.shape == pre_image.shape + self.confusion_matrix += self._generate_matrix(gt_image, pre_image) + + def reset(self): + self.confusion_matrix = np.zeros((self.num_class,) * 2) diff --git a/loss.py b/loss.py new file mode 100644 index 0000000..5de6e78 --- /dev/null +++ b/loss.py @@ -0,0 +1,476 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from collections import Counter +from collections import defaultdict +from torch.autograd import Variable + + +nSamples = [] +weights = [1 - (x / sum(nSamples)) for x in nSamples] +weights = torch.FloatTensor(weights).cuda() + + +def isnan(x): + return x != x + + +def mean(l, ignore_nan=False, empty=0): + """ + nanmean compatible with generators. + """ + l = iter(l) + if ignore_nan: + l = ifilterfalse(isnan, l) + try: + n = 1 + acc = next(l) + except StopIteration: + if empty == 'raise': + raise ValueError('Empty mean') + return empty + for n, v in enumerate(l, 2): + acc += v + if n == 1: + return acc + return acc / n + + +def lovasz_grad(gt_sorted): + """ + Computes gradient of the Lovasz extension w.r.t sorted errors + See Alg. 1 in paper + """ + p = len(gt_sorted) + gts = gt_sorted.sum() + intersection = gts - gt_sorted.float().cumsum(0) + union = gts + (1 - gt_sorted).float().cumsum(0) + jaccard = 1. - intersection / union + if p > 1: # cover 1-pixel case + jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] + return jaccard + + +def iou_binary(preds, labels, EMPTY=1., ignore=None, per_image=True): + """ + IoU for foreground class + binary: 1 foreground, 0 background + """ + if not per_image: + preds, labels = (preds,), (labels,) + ious = [] + for pred, label in zip(preds, labels): + intersection = ((label == 1) & (pred == 1)).sum() + union = ((label == 1) | ((pred == 1) & (label != ignore))).sum() + if not union: + iou = EMPTY + else: + iou = float(intersection) / float(union) + ious.append(iou) + iou = mean(ious) # mean accross images if per_image + return 100 * iou + + +def iou(preds, labels, C, EMPTY=1., ignore=None, per_image=False): + """ + Array of IoU for each (non ignored) class + """ + if not per_image: + preds, labels = (preds,), (labels,) + ious = [] + for pred, label in zip(preds, labels): + iou = [] + for i in range(C): + if i != ignore: # The ignored label is sometimes among predicted classes (ENet - CityScapes) + intersection = ((label == i) & (pred == i)).sum() + union = ((label == i) | ((pred == i) & (label != ignore))).sum() + if not union: + iou.append(EMPTY) + else: + iou.append(float(intersection) / float(union)) + ious.append(iou) + ious = [mean(iou) for iou in zip(*ious)] # mean accross images if per_image + return 100 * np.array(ious) + + +# --------------------------- BINARY LOSSES --------------------------- + +def lovasz_hinge(logits, labels, per_image=True, ignore=None): + """ + Binary Lovasz hinge loss + logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) + labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) + per_image: compute the loss per image instead of per batch + ignore: void class id + """ + if per_image: + loss = mean(lovasz_hinge_flat(*flatten_binary_scores(log.unsqueeze(0), lab.unsqueeze(0), ignore)) + for log, lab in zip(logits, labels)) + else: + loss = lovasz_hinge_flat(*flatten_binary_scores(logits, labels, ignore)) + return loss + + +def lovasz_hinge_flat(logits, labels): + """ + Binary Lovasz hinge loss + logits: [P] Variable, logits at each prediction (between -\infty and +\infty) + labels: [P] Tensor, 
binary ground truth labels (0 or 1) + ignore: label to ignore + """ + if len(labels) == 0: + # only void pixels, the gradients should be 0 + return logits.sum() * 0. + signs = 2. * labels.float() - 1. + errors = (1. - logits * Variable(signs)) + errors_sorted, perm = torch.sort(errors, dim=0, descending=True) + perm = perm.data + gt_sorted = labels[perm] + grad = lovasz_grad(gt_sorted) + loss = torch.dot(F.relu(errors_sorted), Variable(grad)) + return loss + + +def flatten_binary_scores(scores, labels, ignore=None): + """ + Flattens predictions in the batch (binary case) + Remove labels equal to 'ignore' + """ + scores = scores.view(-1) + labels = labels.view(-1) + if ignore is None: + return scores, labels + valid = (labels != ignore) + vscores = scores[valid] + vlabels = labels[valid] + return vscores, vlabels + + +class StableBCELoss(torch.nn.modules.Module): + def __init__(self): + super(StableBCELoss, self).__init__() + def forward(self, input, target): + neg_abs = - input.abs() + loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log() + return loss.mean() + + +def binary_xloss(logits, labels, ignore=None): + """ + Binary Cross entropy loss + logits: [B, H, W] Variable, logits at each pixel (between -\infty and +\infty) + labels: [B, H, W] Tensor, binary ground truth masks (0 or 1) + ignore: void class id + """ + logits, labels = flatten_binary_scores(logits, labels, ignore) + loss = StableBCELoss()(logits, Variable(labels.float())) + return loss + +# --------------------------- MULTICLASS LOSSES --------------------------- +def lovasz_softmax(probas, labels, classes='present', per_image=False, ignore=None): + """ + Multi-class Lovasz-Softmax loss + probas: [B, C, H, W] Variable, class probabilities at each prediction (between 0 and 1). + Interpreted as binary (sigmoid) output with outputs of size [B, H, W]. + labels: [B, H, W] Tensor, ground truth labels (between 0 and C - 1) + classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. + per_image: compute the loss per image instead of per batch + ignore: void class labels + """ + if per_image: + loss = mean(lovasz_softmax_flat(*flatten_probas(prob.unsqueeze(0), lab.unsqueeze(0), ignore), classes=classes) + for prob, lab in zip(probas, labels)) + else: + loss = lovasz_softmax_flat(*flatten_probas(probas, labels, ignore), classes=classes) + return loss + + +def lovasz_softmax_flat(probas, labels, classes='present'): + """ + Multi-class Lovasz-Softmax loss + probas: [P, C] Variable, class probabilities at each prediction (between 0 and 1) + labels: [P] Tensor, ground truth labels (between 0 and C - 1) + classes: 'all' for all, 'present' for classes present in labels, or a list of classes to average. + """ + if probas.numel() == 0: + # only void pixels, the gradients should be 0 + return probas * 0. 
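+    # What follows is Alg. 1 of the Lovasz-Softmax paper, applied per class:
+    # build the binary foreground mask, sort the absolute errors |fg - p_c| in
+    # decreasing order, and take the dot product with the Lovasz gradient of the
+    # sorted ground truth.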
+ C = probas.size(1) + losses = [] + class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes + for c in class_to_sum: + fg = (labels == c).float() # foreground for class c + if (classes is 'present' and fg.sum() == 0): + continue + if C == 1: + if len(classes) > 1: + raise ValueError('Sigmoid output possible only with 1 class') + class_pred = probas[:, 0] + else: + class_pred = probas[:, c] + errors = (Variable(fg) - class_pred).abs() + errors_sorted, perm = torch.sort(errors, 0, descending=True) + perm = perm.data + fg_sorted = fg[perm] + losses.append(torch.dot(errors_sorted, Variable(lovasz_grad(fg_sorted)))) + return mean(losses) + + +def flatten_probas(probas, labels, ignore=None): + """ + Flattens predictions in the batch + """ + if probas.dim() == 3: + # assumes output of a sigmoid layer + B, H, W = probas.size() + probas = probas.view(B, 1, H, W) + B, C, H, W = probas.size() + probas = probas.permute(0, 2, 3, 1).contiguous().view(-1, C) # B * H * W, C = P, C + labels = labels.view(-1) + if ignore is None: + return probas, labels + valid = (labels != ignore) + vprobas = probas[valid.nonzero().squeeze()] + vlabels = labels[valid] + return vprobas, vlabels + + +def make_one_hot(input, num_classes): + """Convert class index tensor to one hot encoding tensor. + Args: + input: A tensor of shape [N, 1, *] + num_classes: An int of number of class + Returns: + A tensor of shape [N, num_classes, *] + """ + input=input.unsqueeze(1) + shape = np.array(input.shape) + shape[1] = num_classes + shape = tuple(shape) + result = torch.zeros(shape) + result = result.scatter_(1, input.cpu(), 1) + return result + +class BinaryDiceLoss(nn.Module): + """Dice loss of binary class + Args: + smooth: A float number to smooth loss, and avoid NaN error, default: 1 + p: Denominator value: \sum{x^p} + \sum{y^p}, default: 2 + predict: A tensor of shape [N, *] + target: A tensor of shape same with predict + reduction: Reduction method to apply, return mean over batch if 'mean', + return sum if 'sum', return a tensor of shape [N,] if 'none' + Returns: + Loss tensor according to arg reduction + Raise: + Exception if unexpected reduction + """ + def __init__(self, smooth=1, p=2, reduction='mean'): + super(BinaryDiceLoss, self).__init__() + self.smooth = smooth + self.p = p + self.reduction = reduction + + def forward(self, predict, target): + assert predict.shape[0] == target.shape[0], "predict & target batch size don't match" + predict = predict.contiguous().view(predict.shape[0], -1) + target = target.contiguous().view(target.shape[0], -1) + + num = torch.sum(torch.mul(predict, target), dim=1) + self.smooth + den = torch.sum(predict.pow(self.p) + target.pow(self.p), dim=1) + self.smooth + + loss = 1 - num / den + + if self.reduction == 'mean': + return loss.mean() + elif self.reduction == 'sum': + return loss.sum() + elif self.reduction == 'none': + return loss + else: + raise Exception('Unexpected reduction {}'.format(self.reduction)) + + +class DiceLoss(nn.Module): + """Dice loss, need one hot encode input + Args: + weight: An array of shape [num_classes,] + ignore_index: class index to ignore + predict: A tensor of shape [N, C, *] + target: A tensor of same shape with predict + other args pass to BinaryDiceLoss + Return: + same as BinaryDiceLoss + """ + def __init__(self, weight=None, ignore_index=None, **kwargs): + super(DiceLoss, self).__init__() + self.kwargs = kwargs + self.weight = weight + self.ignore_index = ignore_index + + def forward(self, predict, target): + assert predict.shape == 
target.shape, 'predict & target shape do not match' + dice = BinaryDiceLoss(**self.kwargs) + total_loss = 0 + predict = F.softmax(predict, dim=1) + + for i in range(target.shape[1]): + if i != self.ignore_index: + dice_loss = dice(predict[:, i], target[:, i]) + if self.weight is not None: + assert self.weight.shape[0] == target.shape[1], \ + 'Expect weight shape [{}], get[{}]'.format(target.shape[1], self.weight.shape[0]) + dice_loss *= self.weights[i] + total_loss += dice_loss + + return total_loss/target.shape[1] + + +class OhemCrossEntropy(nn.Module): + def __init__(self, ignore_label=-1, thres=0.7, + min_kept=100000, weight=None): + super(OhemCrossEntropy, self).__init__() + self.thresh = thres + self.min_kept = max(1, min_kept) + self.ignore_label = ignore_label + self.criterion = nn.CrossEntropyLoss( + weight=weight, + ignore_index=ignore_label, + reduction='none' + ) + + def _ce_forward(self, score, target): + ph, pw = score.size(2), score.size(3) + h, w = target.size(1), target.size(2) + if ph != h or pw != w: + score = F.interpolate(input=score, size=( + h, w), mode='bilinear', align_corners=True) + + loss = self.criterion(score, target) + + return loss + + def _ohem_forward(self, score, target, **kwargs): + ph, pw = score.size(2), score.size(3) + h, w = target.size(1), target.size(2) + if ph != h or pw != w: + score = F.interpolate(input=score, size=( + h, w), mode='bilinear', align_corners=True) + pred = F.softmax(score, dim=1) + pixel_losses = self.criterion(score, target).contiguous().view(-1) + mask = target.contiguous().view(-1) != self.ignore_label + + tmp_target = target.clone() + tmp_target[tmp_target == self.ignore_label] = 0 + pred = pred.gather(1, tmp_target.unsqueeze(1)) + pred, ind = pred.contiguous().view(-1,)[mask].contiguous().sort() + min_value = pred[min(self.min_kept, pred.numel() - 1)] + threshold = max(min_value, self.thresh) + + pixel_losses = pixel_losses[mask][ind] + pixel_losses = pixel_losses[pred < threshold] + return pixel_losses.mean() + + def forward(self, score, target): + score = [score] + weights = [1.] 
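+        # With a single output and weight 1.0, only _ohem_forward is applied;
+        # _ce_forward would only be used for additional (auxiliary) outputs in the list.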
+ assert len(weights) == len(score) + + functions = [self._ce_forward] * (len(weights) - 1) + [self._ohem_forward] + return sum([ + w * func(x, target) + for (w, x, func) in zip(weights, score, functions) + ]) + +class SmoothCrossEntropy(nn.Module): + def __init__(self, ignore_index=255,eps=0.1): + super(SmoothCrossEntropy, self).__init__() + self.eps = eps + self.ignore_label = ignore_index + def forward(self, score, target): + pred = F.softmax(score, dim=1) #nxcxhxw + + mask = target != self.ignore_label + tmp_target = target.clone() + tmp_target[tmp_target == self.ignore_label] = 0 + + + one_hot_labels = torch.zeros([score.shape[0], 9, score.shape[2], score.shape[3]]).cuda() + one_hot_labels.scatter_(1, tmp_target.unsqueeze(1), 1) + K = 9 # number of class + smooth_label = (1 - self.eps) * one_hot_labels + self.eps / (K) #nxcxhxw + loss = torch.sum(torch.mul(-1.*smooth_label,torch.log(pred)),dim=1) + return loss[mask].mean() +# +# def calc_loss(pred, target,metrics): +# criters=nn.CrossEntropyLoss(ignore_index=255) +# ce_loss = criters(pred,target) +# loss = ce_loss +# metrics['loss'] += loss.data.cpu().numpy() +# # metrics['ce_loss'] += 0 +# # metrics['ls_loss'] += 0 +# return loss +class CrossEntropy(nn.Module): + def __init__(self, ignore_label=255, weight=None): + super(CrossEntropy, self).__init__() + self.ignore_label = ignore_label + self.criterion = nn.CrossEntropyLoss( + weight=weight, + reduction='none' + ) + + def _forward(self, score, target): + ph, pw = score.size(2), score.size(3) + h, w = target.size(1), target.size(2) + if ph != h or pw != w: + score = F.interpolate(input=score, size=( + h, w), mode='bilinear', align_corners=True) + + loss = self.criterion(score, target) + + return loss + + def forward(self, score, target): + + hr_weights = [0.4,1] + assert len(hr_weights) == len(score) + loss = hr_weights[0]*self._forward(score[0], target) + hr_weights[1]*self._forward(score[1], target) + return loss + + +def calc_loss(pred, target, edge, metrics): + edge_weight = 4. + criters_ce = CrossEntropy() + loss_ce = criters_ce(pred,target) + loss_ls = lovasz_softmax(F.softmax(pred[1],dim=1),target) + edge[edge == 0] = 1. 
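+    # Edge pixels (255 in the edge map) are up-weighted to edge_weight, all other
+    # pixels keep weight 1 (set just above); after multiplying the per-pixel CE loss
+    # by this map, roughly the hardest 50% of pixels are kept (a simple online hard
+    # example mining step) before averaging.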
+ edge[edge == 255] = edge_weight + loss_ce *= edge + loss_ce_,ind = loss_ce.contiguous().view(-1).sort() + min_value = loss_ce_[int(0.5*loss_ce.shape[0]*loss_ce.shape[1]*loss_ce.shape[2])] + #print(loss_ce.shape) + loss_ce = loss_ce[loss_ce>min_value] + #print(loss_ce.shape) + loss_ce = loss_ce.mean() + loss = loss_ce + loss_ls + metrics['loss'] += loss.data.cpu().numpy() + metrics['ce_loss'] += loss_ce.data.cpu().numpy() + metrics['ls_loss'] += loss_ls.data.cpu().numpy() + return loss + +def calc_smoothloss(pred, target,metrics): + criters=SmoothCrossEntropy(ignore_index=255) + loss=criters(pred,target) + metrics['loss'] += loss.data.cpu().numpy() + return loss + + +if __name__ == '__main__': + criter = BinaryDiceLoss() + target=torch.ones((4,256,256),dtype=torch.long) + input=(torch.ones((4,256,256))*0.9) + loss=criter(input,target) + print(loss) + + diff --git a/model/__pycache__/hrnet_config.cpython-35.pyc b/model/__pycache__/hrnet_config.cpython-35.pyc new file mode 100644 index 0000000..9154a5e Binary files /dev/null and b/model/__pycache__/hrnet_config.cpython-35.pyc differ diff --git a/model/__pycache__/hrnet_config.cpython-36.pyc b/model/__pycache__/hrnet_config.cpython-36.pyc new file mode 100644 index 0000000..1064823 Binary files /dev/null and b/model/__pycache__/hrnet_config.cpython-36.pyc differ diff --git a/model/__pycache__/hrnet_config.cpython-37.pyc b/model/__pycache__/hrnet_config.cpython-37.pyc new file mode 100644 index 0000000..27b9c9c Binary files /dev/null and b/model/__pycache__/hrnet_config.cpython-37.pyc differ diff --git a/model/__pycache__/seg_hrnet.cpython-35.pyc b/model/__pycache__/seg_hrnet.cpython-35.pyc new file mode 100644 index 0000000..9ea0576 Binary files /dev/null and b/model/__pycache__/seg_hrnet.cpython-35.pyc differ diff --git a/model/__pycache__/seg_hrnet.cpython-36.pyc b/model/__pycache__/seg_hrnet.cpython-36.pyc new file mode 100644 index 0000000..fd0c317 Binary files /dev/null and b/model/__pycache__/seg_hrnet.cpython-36.pyc differ diff --git a/model/__pycache__/seg_hrnet.cpython-37.pyc b/model/__pycache__/seg_hrnet.cpython-37.pyc new file mode 100644 index 0000000..f9d790b Binary files /dev/null and b/model/__pycache__/seg_hrnet.cpython-37.pyc differ diff --git a/model/hrnet_config.py b/model/hrnet_config.py new file mode 100644 index 0000000..6edfcc5 --- /dev/null +++ b/model/hrnet_config.py @@ -0,0 +1,130 @@ +# ------------------------------------------------------------------------------ +# Copyright (c) Microsoft +# Licensed under the MIT License. 
+# Create by Bin Xiao (Bin.Xiao@microsoft.com) +# Modified by Ke Sun (sunk@mail.ustc.edu.cn), Rainbowsecret (yuyua@microsoft.com) +# ------------------------------------------------------------------------------ + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from yacs.config import CfgNode as CN + + +# configs for HRNet48 +HRNET_48 = CN() +HRNET_48.FINAL_CONV_KERNEL = 1 + +HRNET_48.STAGE1 = CN() +HRNET_48.STAGE1.NUM_MODULES = 1 +HRNET_48.STAGE1.NUM_BRANCHES = 1 +HRNET_48.STAGE1.NUM_BLOCKS = [4] +HRNET_48.STAGE1.NUM_CHANNELS = [64] +HRNET_48.STAGE1.BLOCK = 'BOTTLENECK' +HRNET_48.STAGE1.FUSE_METHOD = 'SUM' + +HRNET_48.STAGE2 = CN() +HRNET_48.STAGE2.NUM_MODULES = 1 +HRNET_48.STAGE2.NUM_BRANCHES = 2 +HRNET_48.STAGE2.NUM_BLOCKS = [4, 4] +HRNET_48.STAGE2.NUM_CHANNELS = [48, 96] +HRNET_48.STAGE2.BLOCK = 'BASIC' +HRNET_48.STAGE2.FUSE_METHOD = 'SUM' + +HRNET_48.STAGE3 = CN() +HRNET_48.STAGE3.NUM_MODULES = 4 +HRNET_48.STAGE3.NUM_BRANCHES = 3 +HRNET_48.STAGE3.NUM_BLOCKS = [4, 4, 4] +HRNET_48.STAGE3.NUM_CHANNELS = [48, 96, 192] +HRNET_48.STAGE3.BLOCK = 'BASIC' +HRNET_48.STAGE3.FUSE_METHOD = 'SUM' + +HRNET_48.STAGE4 = CN() +HRNET_48.STAGE4.NUM_MODULES = 3 +HRNET_48.STAGE4.NUM_BRANCHES = 4 +HRNET_48.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] +HRNET_48.STAGE4.NUM_CHANNELS = [48, 96, 192, 384] +HRNET_48.STAGE4.BLOCK = 'BASIC' +HRNET_48.STAGE4.FUSE_METHOD = 'SUM' + + +# configs for HRNet32 +HRNET_32 = CN() +HRNET_32.FINAL_CONV_KERNEL = 1 + +HRNET_32.STAGE1 = CN() +HRNET_32.STAGE1.NUM_MODULES = 1 +HRNET_32.STAGE1.NUM_BRANCHES = 1 +HRNET_32.STAGE1.NUM_BLOCKS = [4] +HRNET_32.STAGE1.NUM_CHANNELS = [64] +HRNET_32.STAGE1.BLOCK = 'BOTTLENECK' +HRNET_32.STAGE1.FUSE_METHOD = 'SUM' + +HRNET_32.STAGE2 = CN() +HRNET_32.STAGE2.NUM_MODULES = 1 +HRNET_32.STAGE2.NUM_BRANCHES = 2 +HRNET_32.STAGE2.NUM_BLOCKS = [4, 4] +HRNET_32.STAGE2.NUM_CHANNELS = [32, 64] +HRNET_32.STAGE2.BLOCK = 'BASIC' +HRNET_32.STAGE2.FUSE_METHOD = 'SUM' + +HRNET_32.STAGE3 = CN() +HRNET_32.STAGE3.NUM_MODULES = 4 +HRNET_32.STAGE3.NUM_BRANCHES = 3 +HRNET_32.STAGE3.NUM_BLOCKS = [4, 4, 4] +HRNET_32.STAGE3.NUM_CHANNELS = [32, 64, 128] +HRNET_32.STAGE3.BLOCK = 'BASIC' +HRNET_32.STAGE3.FUSE_METHOD = 'SUM' + +HRNET_32.STAGE4 = CN() +HRNET_32.STAGE4.NUM_MODULES = 3 +HRNET_32.STAGE4.NUM_BRANCHES = 4 +HRNET_32.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] +HRNET_32.STAGE4.NUM_CHANNELS = [32, 64, 128, 256] +HRNET_32.STAGE4.BLOCK = 'BASIC' +HRNET_32.STAGE4.FUSE_METHOD = 'SUM' + + +# configs for HRNet18 +HRNET_18 = CN() +HRNET_18.FINAL_CONV_KERNEL = 1 + +HRNET_18.STAGE1 = CN() +HRNET_18.STAGE1.NUM_MODULES = 1 +HRNET_18.STAGE1.NUM_BRANCHES = 1 +HRNET_18.STAGE1.NUM_BLOCKS = [4] +HRNET_18.STAGE1.NUM_CHANNELS = [64] +HRNET_18.STAGE1.BLOCK = 'BOTTLENECK' +HRNET_18.STAGE1.FUSE_METHOD = 'SUM' + +HRNET_18.STAGE2 = CN() +HRNET_18.STAGE2.NUM_MODULES = 1 +HRNET_18.STAGE2.NUM_BRANCHES = 2 +HRNET_18.STAGE2.NUM_BLOCKS = [4, 4] +HRNET_18.STAGE2.NUM_CHANNELS = [18, 36] +HRNET_18.STAGE2.BLOCK = 'BASIC' +HRNET_18.STAGE2.FUSE_METHOD = 'SUM' + +HRNET_18.STAGE3 = CN() +HRNET_18.STAGE3.NUM_MODULES = 4 +HRNET_18.STAGE3.NUM_BRANCHES = 3 +HRNET_18.STAGE3.NUM_BLOCKS = [4, 4, 4] +HRNET_18.STAGE3.NUM_CHANNELS = [18, 36, 72] +HRNET_18.STAGE3.BLOCK = 'BASIC' +HRNET_18.STAGE3.FUSE_METHOD = 'SUM' + +HRNET_18.STAGE4 = CN() +HRNET_18.STAGE4.NUM_MODULES = 3 +HRNET_18.STAGE4.NUM_BRANCHES = 4 +HRNET_18.STAGE4.NUM_BLOCKS = [4, 4, 4, 4] +HRNET_18.STAGE4.NUM_CHANNELS = [18, 36, 72, 144] +HRNET_18.STAGE4.BLOCK = 'BASIC' +HRNET_18.STAGE4.FUSE_METHOD = 'SUM' + + 
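+# The 18/32/48 suffix is the channel width of the highest-resolution branch; each
+# coarser branch doubles it (W18: 18/36/72/144, W32: 32/64/128/256, W48: 48/96/192/384).
+#
+# Minimal usage sketch (this mirrors _hrnet() in seg_hrnet.py, which builds
+# HighResolutionNet(MODEL_CONFIGS[arch]) from the dict defined just below):
+#
+#   from model.hrnet_config import MODEL_CONFIGS
+#   from model.seg_hrnet import HighResolutionNet
+#   net = HighResolutionNet(MODEL_CONFIGS['hrnet18'])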
+MODEL_CONFIGS = { + 'hrnet18': HRNET_18, + 'hrnet32': HRNET_32, + 'hrnet48': HRNET_48, +} \ No newline at end of file diff --git a/model/seg_hrnet.py b/model/seg_hrnet.py new file mode 100644 index 0000000..d2a8590 --- /dev/null +++ b/model/seg_hrnet.py @@ -0,0 +1,750 @@ +""" +MIT License +Copyright (c) 2019 Microsoft +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" +import os +import logging +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +logger = logging.getLogger('hrnet_backbone') + +__all__ = ['hrnet18', 'hrnet32', 'hrnet48'] + + +model_urls = { + # all the checkpoints come from https://github.com/HRNet/HRNet-Image-Classification + 'hrnet18': 'https://opr0mq.dm.files.1drv.com/y4mIoWpP2n-LUohHHANpC0jrOixm1FZgO2OsUtP2DwIozH5RsoYVyv_De5wDgR6XuQmirMV3C0AljLeB-zQXevfLlnQpcNeJlT9Q8LwNYDwh3TsECkMTWXCUn3vDGJWpCxQcQWKONr5VQWO1hLEKPeJbbSZ6tgbWwJHgHF7592HY7ilmGe39o5BhHz7P9QqMYLBts6V7QGoaKrr0PL3wvvR4w', + 'hrnet32': 'https://opr74a.dm.files.1drv.com/y4mKOuRSNGQQlp6wm_a9bF-UEQwp6a10xFCLhm4bqjDu6aSNW9yhDRM7qyx0vK0WTh42gEaniUVm3h7pg0H-W0yJff5qQtoAX7Zze4vOsqjoIthp-FW3nlfMD0-gcJi8IiVrMWqVOw2N3MbCud6uQQrTaEAvAdNjtjMpym1JghN-F060rSQKmgtq5R-wJe185IyW4-_c5_ItbhYpCyLxdqdEQ', + 'hrnet48': 'https://optgaw.dm.files.1drv.com/y4mWNpya38VArcDInoPaL7GfPMgcop92G6YRkabO1QTSWkCbo7djk8BFZ6LK_KHHIYE8wqeSAChU58NVFOZEvqFaoz392OgcyBrq_f8XGkusQep_oQsuQ7DPQCUrdLwyze_NlsyDGWot0L9agkQ-M_SfNr10ETlCF5R7BdKDZdupmcMXZc-IE3Ysw1bVHdOH4l-XEbEKFAi6ivPUbeqlYkRMQ' +} + +class ModuleHelper: + + @staticmethod + def BNReLU(num_features, bn_type=None, **kwargs): + return nn.Sequential( + nn.BatchNorm2d(num_features, **kwargs), + nn.ReLU() + ) + + @staticmethod + def BatchNorm2d(*args, **kwargs): + return nn.BatchNorm2d +class SpatialGather_Module(nn.Module): + """ + Aggregate the context features according to the initial + predicted probability distribution. + Employ the soft-weighted method to aggregate the context. 
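+    The output is a (batch, channels, num_probability_channels, 1) tensor holding
+    one aggregated context vector per probability channel (class).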
+ """ + def __init__(self, cls_num=0, scale=1): + super(SpatialGather_Module, self).__init__() + self.cls_num = cls_num + self.scale = scale + + def forward(self, feats, probs): + batch_size, c, h, w = probs.size(0), probs.size(1), probs.size(2), probs.size(3) + probs = probs.view(batch_size, c, -1) + feats = feats.view(batch_size, feats.size(1), -1) + feats = feats.permute(0, 2, 1) # batch x hw x c + probs = F.softmax(self.scale * probs, dim=2)# batch x k x hw + # print(probs.shape,feats.shape) + ocr_context = torch.matmul(probs, feats)\ + .permute(0, 2, 1).unsqueeze(3)# batch x k x c + return ocr_context + + +class _ObjectAttentionBlock(nn.Module): + ''' + The basic implementation for object context block + Input: + N X C X H X W + Parameters: + in_channels : the dimension of the input feature map + key_channels : the dimension after the key/query transform + scale : choose the scale to downsample the input feature maps (save memory cost) + bn_type : specify the bn type + Return: + N X C X H X W + ''' + def __init__(self, + in_channels, + key_channels, + scale=1, + bn_type=None): + super(_ObjectAttentionBlock, self).__init__() + self.scale = scale + self.in_channels = in_channels + self.key_channels = key_channels + self.pool = nn.MaxPool2d(kernel_size=(scale, scale)) + self.f_pixel = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_object = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + nn.Conv2d(in_channels=self.key_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_down = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.key_channels, + kernel_size=1, stride=1, padding=0, bias=False), + ModuleHelper.BNReLU(self.key_channels, bn_type=bn_type), + ) + self.f_up = nn.Sequential( + nn.Conv2d(in_channels=self.key_channels, out_channels=self.in_channels, + kernel_size=1, stride=1, padding=0, bias=False), + ModuleHelper.BNReLU(self.in_channels, bn_type=bn_type), + ) + + def forward(self, x, proxy): + batch_size, h, w = x.size(0), x.size(2), x.size(3) + if self.scale > 1: + x = self.pool(x) + + query = self.f_pixel(x).view(batch_size, self.key_channels, -1) + query = query.permute(0, 2, 1) + key = self.f_object(proxy).view(batch_size, self.key_channels, -1) + value = self.f_down(proxy).view(batch_size, self.key_channels, -1) + value = value.permute(0, 2, 1) + + sim_map = torch.matmul(query, key) + sim_map = (self.key_channels**-.5) * sim_map + sim_map = F.softmax(sim_map, dim=-1) + + # add bg context ... 
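+        # sim_map holds, for every pixel, a softmax-normalized similarity to each
+        # object-region descriptor; multiplying by value gathers those region
+        # features back to the pixels - the object-contextual (OCR) attention step.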
+ context = torch.matmul(sim_map, value) + context = context.permute(0, 2, 1).contiguous() + context = context.view(batch_size, self.key_channels, *x.size()[2:]) + context = self.f_up(context) + if self.scale > 1: + context = F.interpolate(input=context, size=(h, w), mode='bilinear', align_corners=True) + + return context + +class ObjectAttentionBlock2D(_ObjectAttentionBlock): + def __init__(self, + in_channels, + key_channels, + scale=1, + bn_type=None): + super(ObjectAttentionBlock2D, self).__init__(in_channels, + key_channels, + scale, + bn_type=bn_type) + + +class SpatialOCR_Module(nn.Module): + """ + Implementation of the OCR module: + We aggregate the global object representation to update the representation for each pixel. + """ + def __init__(self, + in_channels, + key_channels, + out_channels, + scale=1, + dropout=0.1, + bn_type=None): + super(SpatialOCR_Module, self).__init__() + self.object_context_block = ObjectAttentionBlock2D(in_channels, + key_channels, + scale, + bn_type) + _in_channels = 2 * in_channels + + self.conv_bn_dropout = nn.Sequential( + nn.Conv2d(_in_channels, out_channels, kernel_size=1, padding=0, bias=False), + ModuleHelper.BNReLU(out_channels, bn_type=bn_type), + nn.Dropout2d(dropout) + ) + + def forward(self, feats, proxy_feats): + context = self.object_context_block(feats, proxy_feats) + + output = self.conv_bn_dropout(torch.cat([context, feats], 1)) + + return output + + + + +def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): + """3x3 convolution with padding""" + return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, + padding=dilation, groups=groups, bias=False, dilation=dilation) + + +def conv1x1(in_planes, out_planes, stride=1): + """1x1 convolution""" + return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) + + +class BasicBlock(nn.Module): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(BasicBlock, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + if groups != 1 or base_width != 64: + raise ValueError('BasicBlock only supports groups=1 and base_width=64') + if dilation > 1: + raise NotImplementedError("Dilation > 1 not supported in BasicBlock") + # Both self.conv1 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv3x3(inplanes, planes, stride) + self.bn1 = norm_layer(planes) + self.relu = nn.ReLU(inplace=True) + self.conv2 = conv3x3(planes, planes) + self.bn2 = norm_layer(planes) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Module): + + expansion = 4 + + def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, + base_width=64, dilation=1, norm_layer=None): + super(Bottleneck, self).__init__() + if norm_layer is None: + norm_layer = nn.BatchNorm2d + width = int(planes * (base_width / 64.)) * groups + # Both self.conv2 and self.downsample layers downsample the input when stride != 1 + self.conv1 = conv1x1(inplanes, width) + self.bn1 = norm_layer(width) + self.conv2 = conv3x3(width, width, stride, groups, dilation) + self.bn2 = norm_layer(width) + self.conv3 = conv1x1(width, planes * self.expansion) + self.bn3 = 
norm_layer(planes * self.expansion) + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + + def forward(self, x): + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + out = self.relu(out) + + return out + + +class HighResolutionModule(nn.Module): + def __init__(self, num_branches, blocks, num_blocks, num_inchannels, + num_channels, fuse_method, multi_scale_output=True, norm_layer=None): + super(HighResolutionModule, self).__init__() + self._check_branches( + num_branches, blocks, num_blocks, num_inchannels, num_channels) + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.norm_layer = norm_layer + + self.num_inchannels = num_inchannels + self.fuse_method = fuse_method + self.num_branches = num_branches + + self.multi_scale_output = multi_scale_output + + self.branches = self._make_branches( + num_branches, blocks, num_blocks, num_channels) + self.fuse_layers = self._make_fuse_layers() + self.relu = nn.ReLU(inplace=True) + + def _check_branches(self, num_branches, blocks, num_blocks, + num_inchannels, num_channels): + if num_branches != len(num_blocks): + error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})'.format( + num_branches, len(num_blocks)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_channels): + error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})'.format( + num_branches, len(num_channels)) + logger.error(error_msg) + raise ValueError(error_msg) + + if num_branches != len(num_inchannels): + error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})'.format( + num_branches, len(num_inchannels)) + logger.error(error_msg) + raise ValueError(error_msg) + + def _make_one_branch(self, branch_index, block, num_blocks, num_channels, + stride=1): + downsample = None + if stride != 1 or \ + self.num_inchannels[branch_index] != num_channels[branch_index] * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(self.num_inchannels[branch_index], + num_channels[branch_index] * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(num_channels[branch_index] * block.expansion), + ) + + layers = [] + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], stride, downsample, norm_layer=self.norm_layer)) + self.num_inchannels[branch_index] = \ + num_channels[branch_index] * block.expansion + for i in range(1, num_blocks[branch_index]): + layers.append(block(self.num_inchannels[branch_index], + num_channels[branch_index], norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_branches(self, num_branches, block, num_blocks, num_channels): + branches = [] + + for i in range(num_branches): + branches.append( + self._make_one_branch(i, block, num_blocks, num_channels)) + + return nn.ModuleList(branches) + + def _make_fuse_layers(self): + if self.num_branches == 1: + return None + + num_branches = self.num_branches + num_inchannels = self.num_inchannels + fuse_layers = [] + for i in range(num_branches if self.multi_scale_output else 1): + fuse_layer = [] + for j in range(num_branches): + if j > i: + fuse_layer.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_inchannels[i], + 1, + 1, + 0, + bias=False), + self.norm_layer(num_inchannels[i]))) + elif j == i: + fuse_layer.append(None) + else: + conv3x3s = 
[] + for k in range(i-j): + if k == i - j - 1: + num_outchannels_conv3x3 = num_inchannels[i] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + self.norm_layer(num_outchannels_conv3x3))) + else: + num_outchannels_conv3x3 = num_inchannels[j] + conv3x3s.append(nn.Sequential( + nn.Conv2d(num_inchannels[j], + num_outchannels_conv3x3, + 3, 2, 1, bias=False), + self.norm_layer(num_outchannels_conv3x3), + nn.ReLU(inplace=True))) + fuse_layer.append(nn.Sequential(*conv3x3s)) + fuse_layers.append(nn.ModuleList(fuse_layer)) + + return nn.ModuleList(fuse_layers) + + def get_num_inchannels(self): + return self.num_inchannels + + def forward(self, x): + if self.num_branches == 1: + return [self.branches[0](x[0])] + + for i in range(self.num_branches): + x[i] = self.branches[i](x[i]) + + x_fuse = [] + for i in range(len(self.fuse_layers)): + y = x[0] if i == 0 else self.fuse_layers[i][0](x[0]) + for j in range(1, self.num_branches): + if i == j: + y = y + x[j] + elif j > i: + width_output = x[i].shape[-1] + height_output = x[i].shape[-2] + y = y + F.interpolate( + self.fuse_layers[i][j](x[j]), + size=[height_output, width_output], + mode='bilinear', + align_corners=True + ) + else: + y = y + self.fuse_layers[i][j](x[j]) + x_fuse.append(self.relu(y)) + + return x_fuse + + +blocks_dict = { + 'BASIC': BasicBlock, + 'BOTTLENECK': Bottleneck +} + + +class HighResolutionNet(nn.Module): + + def __init__(self, + cfg, + norm_layer=None): + super(HighResolutionNet, self).__init__() + + if norm_layer is None: + norm_layer = nn.BatchNorm2d + self.norm_layer = norm_layer + # stem network + # stem net + self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn1 = self.norm_layer(64) + self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, + bias=False) + self.bn2 = self.norm_layer(64) + self.relu = nn.ReLU(inplace=True) + + # stage 1 + self.stage1_cfg = cfg['STAGE1'] + num_channels = self.stage1_cfg['NUM_CHANNELS'][0] + block = blocks_dict[self.stage1_cfg['BLOCK']] + num_blocks = self.stage1_cfg['NUM_BLOCKS'][0] + self.layer1 = self._make_layer(block, 64, num_channels, num_blocks) + stage1_out_channel = block.expansion*num_channels + + # stage 2 + self.stage2_cfg = cfg['STAGE2'] + num_channels = self.stage2_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage2_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition1 = self._make_transition_layer( + [stage1_out_channel], num_channels) + self.stage2, pre_stage_channels = self._make_stage( + self.stage2_cfg, num_channels) + + # stage 3 + self.stage3_cfg = cfg['STAGE3'] + num_channels = self.stage3_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage3_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition2 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage3, pre_stage_channels = self._make_stage( + self.stage3_cfg, num_channels) + + # stage 4 + self.stage4_cfg = cfg['STAGE4'] + num_channels = self.stage4_cfg['NUM_CHANNELS'] + block = blocks_dict[self.stage4_cfg['BLOCK']] + num_channels = [ + num_channels[i] * block.expansion for i in range(len(num_channels))] + self.transition3 = self._make_transition_layer( + pre_stage_channels, num_channels) + self.stage4, pre_stage_channels = self._make_stage( + self.stage4_cfg, num_channels, multi_scale_output=True) + + last_inp_channels = 
np.int(np.sum(pre_stage_channels)) + + MID_CHANNELS = 512 + KEY_CHANNELS = 256 + last_inp_channels = np.int(np.sum(pre_stage_channels)) + ocr_mid_channels = MID_CHANNELS + ocr_key_channels = KEY_CHANNELS + + self.conv3x3_ocr = nn.Sequential( + nn.Conv2d(last_inp_channels, ocr_mid_channels, + kernel_size=3, stride=1, padding=1), + norm_layer(ocr_mid_channels), + nn.ReLU(inplace=True), + ) + self.ocr_gather_head = SpatialGather_Module(9) + + self.ocr_distri_head = SpatialOCR_Module(in_channels=ocr_mid_channels, + key_channels=ocr_key_channels, + out_channels=ocr_mid_channels, + scale=1, + dropout=0.05, + ) + self.cls_head = nn.Conv2d( + ocr_mid_channels, 2, kernel_size=1, stride=1, padding=0, bias=True) + + self.aux_head = nn.Sequential( + nn.Conv2d(last_inp_channels, last_inp_channels, + kernel_size=1, stride=1, padding=0), + norm_layer(last_inp_channels), + nn.ReLU(inplace=True), + nn.Conv2d(last_inp_channels,2, + kernel_size=1, stride=1, padding=0, bias=True) + ) + + + + def _make_transition_layer( + self, num_channels_pre_layer, num_channels_cur_layer): + num_branches_cur = len(num_channels_cur_layer) + num_branches_pre = len(num_channels_pre_layer) + + transition_layers = [] + for i in range(num_branches_cur): + if i < num_branches_pre: + if num_channels_cur_layer[i] != num_channels_pre_layer[i]: + transition_layers.append(nn.Sequential( + nn.Conv2d(num_channels_pre_layer[i], + num_channels_cur_layer[i], + 3, + 1, + 1, + bias=False), + self.norm_layer(num_channels_cur_layer[i]), + nn.ReLU(inplace=True))) + else: + transition_layers.append(None) + else: + conv3x3s = [] + for j in range(i+1-num_branches_pre): + inchannels = num_channels_pre_layer[-1] + outchannels = num_channels_cur_layer[i] \ + if j == i-num_branches_pre else inchannels + conv3x3s.append(nn.Sequential( + nn.Conv2d( + inchannels, outchannels, 3, 2, 1, bias=False), + self.norm_layer(outchannels), + nn.ReLU(inplace=True))) + transition_layers.append(nn.Sequential(*conv3x3s)) + + return nn.ModuleList(transition_layers) + + def _make_layer(self, block, inplanes, planes, blocks, stride=1): + downsample = None + if stride != 1 or inplanes != planes * block.expansion: + downsample = nn.Sequential( + nn.Conv2d(inplanes, planes * block.expansion, + kernel_size=1, stride=stride, bias=False), + self.norm_layer(planes * block.expansion), + ) + + layers = [] + layers.append(block(inplanes, planes, stride, downsample, norm_layer=self.norm_layer)) + inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(inplanes, planes, norm_layer=self.norm_layer)) + + return nn.Sequential(*layers) + + def _make_stage(self, layer_config, num_inchannels, + multi_scale_output=True): + num_modules = layer_config['NUM_MODULES'] + num_branches = layer_config['NUM_BRANCHES'] + num_blocks = layer_config['NUM_BLOCKS'] + num_channels = layer_config['NUM_CHANNELS'] + block = blocks_dict[layer_config['BLOCK']] + fuse_method = layer_config['FUSE_METHOD'] + + modules = [] + for i in range(num_modules): + # multi_scale_output is only used last module + if not multi_scale_output and i == num_modules - 1: + reset_multi_scale_output = False + else: + reset_multi_scale_output = True + + modules.append( + HighResolutionModule(num_branches, + block, + num_blocks, + num_inchannels, + num_channels, + fuse_method, + reset_multi_scale_output, + norm_layer=self.norm_layer) + ) + num_inchannels = modules[-1].get_num_inchannels() + + return nn.Sequential(*modules), num_inchannels + + + def forward(self, input): + x = self.conv1(input) + x = 
+        x = self.relu(x)
+        x = self.conv2(x)
+        x = self.bn2(x)
+        x = self.relu(x)
+        x = self.layer1(x)
+
+        x_list = []
+        for i in range(self.stage2_cfg['NUM_BRANCHES']):
+            if self.transition1[i] is not None:
+                x_list.append(self.transition1[i](x))
+            else:
+                x_list.append(x)
+        y_list = self.stage2(x_list)
+
+        x_list = []
+        for i in range(self.stage3_cfg['NUM_BRANCHES']):
+            if self.transition2[i] is not None:
+                if i < self.stage2_cfg['NUM_BRANCHES']:
+                    x_list.append(self.transition2[i](y_list[i]))
+                else:
+                    x_list.append(self.transition2[i](y_list[-1]))
+            else:
+                x_list.append(y_list[i])
+        y_list = self.stage3(x_list)
+
+        x_list = []
+        for i in range(self.stage4_cfg['NUM_BRANCHES']):
+            if self.transition3[i] is not None:
+                if i < self.stage3_cfg['NUM_BRANCHES']:
+                    x_list.append(self.transition3[i](y_list[i]))
+                else:
+                    x_list.append(self.transition3[i](y_list[-1]))
+            else:
+                x_list.append(y_list[i])
+        x = self.stage4(x_list)
+
+        outputs = {}
+        # See note [TorchScript super()]
+        outputs['res2'] = x[0]  # 1/4
+        outputs['res3'] = x[1]  # 1/8
+        outputs['res4'] = x[2]  # 1/16
+        outputs['res5'] = x[3]  # 1/32
+        x0_h, x0_w = x[0].size(2), x[0].size(3)
+        ALIGN_CORNERS = True
+        x1 = F.interpolate(x[1], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
+        x2 = F.interpolate(x[2], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
+        x3 = F.interpolate(x[3], size=(x0_h, x0_w), mode='bilinear', align_corners=ALIGN_CORNERS)
+
+        feats = torch.cat([x[0], x1, x2, x3], 1)
+
+        out_aux_seg = []
+
+        # ocr
+        out_aux = self.aux_head(feats)
+        # compute contrast feature
+        feats = self.conv3x3_ocr(feats)
+
+        context = self.ocr_gather_head(feats, out_aux)
+        feats = self.ocr_distri_head(feats, context)
+
+        out = self.cls_head(feats)
+
+        out_aux_seg.append(out_aux)
+        out_aux_seg.append(out)
+
+        return out_aux_seg
+
+
+def _hrnet(arch, pretrained, progress, **kwargs):
+    try:
+        from .hrnet_config import MODEL_CONFIGS
+    except ImportError:
+        # fall back to a plain import when this file is run as a script
+        from hrnet_config import MODEL_CONFIGS
+    model = HighResolutionNet(MODEL_CONFIGS[arch], **kwargs)
+    if pretrained:
+        # if int(os.environ.get("mapillary_pretrain", 0)):
+        #     logger.info("load the mapillary pretrained hrnet-w48 weights.")
+        #     model_url = model_urls['hrnet48_mapillary_pretrain']
+        # else:
+        #     model_url = model_urls[arch]
+
+        # NOTE: the path below is hard-coded to the HRNetV2-W18 ImageNet checkpoint
+        pretrained_dict = torch.load('./pre-trained_weights/hrnetv2_w18_imagenet_pretrained.pth')
+        print('=> loading ImageNet-pretrained weights for {}'.format(arch))
+        model_dict = model.state_dict()
+        pretrained_dict = {k: v for k, v in pretrained_dict.items()
+                           if k in model_dict.keys()}
+        # for k, _ in pretrained_dict.items():
+        #     print('=> loading {} pretrained model {}'.format(k, pretrained))
+        model_dict.update(pretrained_dict)
+        model.load_state_dict(model_dict)
+        # print(model.conv1.weight[0,0])
+    return model
+
+
+def hrnet18(pretrained=False, progress=True, **kwargs):
+    r"""HRNet-18 model
+    """
+    return _hrnet('hrnet18', pretrained, progress,
+                  **kwargs)
+
+
+def hrnet32(pretrained=False, progress=True, **kwargs):
+    r"""HRNet-32 model
+    """
+    return _hrnet('hrnet32', pretrained, progress,
+                  **kwargs)
+
+
+def hrnet48(pretrained=False, progress=True, **kwargs):
+    r"""HRNet-48 model
+    """
+    return _hrnet('hrnet48', pretrained, progress,
+                  **kwargs)
+
+
+if __name__ == '__main__':
+    model = hrnet32(pretrained=True)
+    # print(model)
+    input = torch.randn([2,3,512,512],dtype=torch.float)
+    output = model(input)
+    print(output[0].shape)
diff --git a/run.sh b/run.sh
new file mode 100644
index 0000000..9362ed4
--- /dev/null
+++ 
b/run.sh @@ -0,0 +1,2 @@ +python cut_data.py +CUDA_VISIBLE_DEVICES=0 python train.py --backbone=hrnet --batchsize=4 --lr=0.01 --num_epochs=150 diff --git a/train.py b/train.py new file mode 100644 index 0000000..a78fb04 --- /dev/null +++ b/train.py @@ -0,0 +1,265 @@ +#coding=utf-8 +from collections import defaultdict +import torch.nn.functional as F +from loss import calc_loss,calc_smoothloss +import time +import torch +from torch.utils.data import Dataset, DataLoader +from utils.lr_scheduler import adjust_learning_rate_poly +from utils.ema import WeightEMA +from utils.label2color import label_img_to_color, diff_label_img_to_color +from data.make_data import GaofenTrain, GaofenVal +from tqdm import tqdm +from evaluate import Evaluator +import numpy as np +import argparse +import matplotlib.pyplot as plt +import itertools +import torch.nn as nn +import cv2 +import os +from tensorboardX import SummaryWriter +from model.seg_hrnet import hrnet18 + +parser = argparse.ArgumentParser() +parser.add_argument('--root', default='./data') +parser.add_argument('--train_list_path', default='./data/train.txt') +parser.add_argument('--val_list_path', default='./data/val.txt') +# parser.add_argument('--test_list_path', default='./data/val_1w.txt') +parser.add_argument('--backbone', default='resnest50', type=str,help='xception|resnet|resnest101|resnest200|resnest50|resnest26') +parser.add_argument('--n_cls', default=2, type=int) +parser.add_argument('--batchsize', default=4, type=int) +parser.add_argument('--lr', default=0.01, type=float) +parser.add_argument('--num_epochs', default=150, type=int) +parser.add_argument('--warmup', default=100, type=int) +parser.add_argument('--multiplier', default=100, type=int) +parser.add_argument('--eta_min', default=0.0005, type=float) +parser.add_argument('--num_workers', default=8, type=int) +parser.add_argument('--decay_rate', default=0.8, type=float) +parser.add_argument('--decay_epoch', default=200, type=int) +parser.add_argument('--vis_frequency', default=30, type=int) +parser.add_argument('--save_path', default='./results') +parser.add_argument('--gpu-id', default='0,1', type=str, + help='id(s) for CUDA_VISIBLE_DEVICES') +parser.add_argument('--is_resume', default=False, type=bool) +parser.add_argument('--resume', default='', type=str, help='./results/checkpoint.pth') +args = parser.parse_args() + + +folder_path = '/backbone={}/warmup={}_lr={}_multiplier={}_eta_min={}_num_epochs={}_batchsize={}'.format(args.backbone,args.warmup,args.lr, args.multiplier, args.eta_min, args.num_epochs,args.batchsize) +isExists = os.path.exists(args.save_path + folder_path) +if not isExists: + os.makedirs(args.save_path + folder_path) + +isExists = os.path.exists(args.save_path +folder_path+'/vis') +if not isExists: os.makedirs(args.save_path + folder_path+'/vis') + +isExists = os.path.exists(args.save_path +folder_path+'/log') +if not isExists: os.makedirs(args.save_path + folder_path+'/log') + +os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_id +device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') +global_step = 0 +def train_model(): + global global_step + F_txt = open('./opt_results.txt', 'w') + evaluator = Evaluator(args.n_cls) + classes = ['road', 'others'] + writer = SummaryWriter(args.save_path + folder_path + '/log') + def create_model(ema=False): + model = hrnet18(pretrained=True).to(device) + if ema: + for param in model.parameters(): + param.detach_() + return model + model = create_model() + ema_model = create_model(ema=True) + # model = 
hrnet18(pretrained=True).to(device)
+    # model = nn.DataParallel(model)
+    optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0001)
+    ema_optimizer = WeightEMA(model, ema_model, alpha=0.999)
+    best_miou = 0.
+    best_AA = 0.
+    best_OA = 0.
+    best_loss = 0.
+    lr = args.lr
+    epoch_index = 0
+    if args.is_resume:
+        args.resume = args.save_path + folder_path+'/checkpoint_fwiou.pth'
+        if os.path.isfile(args.resume):
+            checkpoint = torch.load(args.resume)
+            epoch_index = checkpoint['epoch']
+            best_miou = checkpoint['miou']
+            model.load_state_dict(checkpoint['state_dict'])
+            optimizer.load_state_dict(checkpoint['optimizer'])
+            lr = optimizer.param_groups[0]['lr']
+            print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
+            F_txt.write("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch'])+'\n')
+            # print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']), file=F_txt)
+        else:
+            print('ERROR: No such file: {}'.format(args.resume))
+
+    TRAIN_DATA_DIRECTORY = args.root  # '/media/ws/www/IGARSS'
+    TRAIN_DATA_LIST_PATH = args.train_list_path  # '/media/ws/www/unet_1/data/train.txt'
+
+    VAL_DATA_DIRECTORY = args.root  # '/media/ws/www/IGARSS'
+    VAL_DATA_LIST_PATH = args.val_list_path  # '/media/ws/www/unet_1/data/train.txt'
+
+    dataloaders = {
+        "train": DataLoader(GaofenTrain(TRAIN_DATA_DIRECTORY, TRAIN_DATA_LIST_PATH), batch_size=args.batchsize,
+                            shuffle=True, num_workers=args.num_workers,pin_memory=True,drop_last=True),
+        "val": DataLoader(GaofenVal(VAL_DATA_DIRECTORY, VAL_DATA_LIST_PATH), batch_size=args.batchsize,
+                          num_workers=args.num_workers,pin_memory=True)
+    }
+
+    evaluator.reset()
+    print('config: ' + folder_path)
+    print('config: ' + folder_path, file=F_txt,flush=True)
+    for epoch in range(epoch_index, args.num_epochs):
+        print('Epoch [{}]/[{}] lr={:6f}'.format(epoch + 1, args.num_epochs, lr))
+        # F_txt.write('Epoch [{}]/[{}] lr={:6f}'.format(epoch + 1, args.num_epochs, lr)+'\n')
+        print('Epoch [{}]/[{}] lr={:6f}'.format(epoch + 1, args.num_epochs, lr), file=F_txt,flush=True)
+        since = time.time()
+
+        # Each epoch has a training and validation phase
+        for phase in ['train', 'val']:
+            evaluator.reset()
+            if phase == 'train':
+                model.train()  # Set model to training mode
+            else:
+                ema_model.eval()
+                model.eval()   # Set model to evaluate mode
+
+            metrics = defaultdict(float)
+            epoch_samples = 0
+
+            for i, (inputs, labels, edge, _, datafiles) in enumerate(tqdm(dataloaders[phase],ncols=50)):
+                inputs = inputs.to(device)
+                edge = edge.to(device, dtype=torch.float)
+                labels = labels.to(device, dtype=torch.long)
+
+                optimizer.zero_grad()
+
+                with torch.set_grad_enabled(phase == 'train'):
+                    if phase == 'train':
+                        outputs = model(inputs)
+                        outputs[1] = F.interpolate(input=outputs[1], size=(
+                            labels.shape[1],labels.shape[2]), mode='bilinear', align_corners=True)
+                        loss = calc_loss(outputs, labels, edge, metrics)
+                        pred = outputs[1].data.cpu().numpy()
+                        pred = np.argmax(pred, axis=1)
+                        labels = labels.data.cpu().numpy()
+                        evaluator.add_batch(labels, pred)
+                    if phase == 'val':
+                        outputs = ema_model(inputs)
+                        outputs[1] = F.interpolate(input=outputs[1], size=(
+                            labels.shape[1],labels.shape[2]), mode='bilinear', align_corners=True)
+                        loss = calc_loss(outputs, labels, edge, metrics)
+                        pred = outputs[1].data.cpu().numpy()
+                        pred = np.argmax(pred, axis=1)
+                        labels = labels.data.cpu().numpy()
+                        evaluator.add_batch(labels, pred)
+                    if phase == 'val' and (epoch+1)%args.vis_frequency==0 and inputs.shape[0]==args.batchsize:
+                        for k in range(args.batchsize//2):
+                            name = datafiles['name'][k][:-4]
+
+                            writer.add_image('{}/img'.format(name),cv2.cvtColor(cv2.imread(datafiles["img"][k], cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB),global_step=int((epoch+1)),dataformats='HWC')
+
+                            writer.add_image('{}/gt'.format(name), label_img_to_color(labels[k])[:,:,::-1],global_step=int((epoch+1)),dataformats='HWC')
+
+                            pred_label_img = pred.astype(np.uint8)[k]
+                            pred_label_img_color = label_img_to_color(pred_label_img)
+                            writer.add_image('{}/mask'.format(name),pred_label_img_color[:,:,::-1],global_step=int((epoch+1)),dataformats='HWC')
+
+                            softmax_pred = F.softmax(outputs[1][k],dim=0)
+                            softmax_pred_np = softmax_pred.data.cpu().numpy()
+                            probability = softmax_pred_np[0]
+                            probability = probability*255
+                            probability = probability.astype(np.uint8)
+                            probability = cv2.applyColorMap(probability,cv2.COLORMAP_HOT)
+                            writer.add_image('{}/prob'.format(name),cv2.cvtColor(probability,cv2.COLOR_BGR2RGB),global_step=int((epoch+1)),dataformats='HWC')
+                            # difference map: keep the ground-truth label wherever the prediction disagrees
+                            diff_img = np.ones((pred_label_img.shape[0], pred_label_img.shape[1]), dtype=np.int32)*255
+                            mask = (labels[k] != pred_label_img)
+                            diff_img[mask] = labels[k][mask]
+                            diff_img_color = diff_label_img_to_color(diff_img)
+                            writer.add_image('{}/different_image'.format(name), diff_img_color[:, :, ::-1],
+                                             global_step=int((epoch + 1)), dataformats='HWC')
+                    if phase == 'train':
+                        loss.backward()
+                        global_step += 1
+                        optimizer.step()
+                        ema_optimizer.step()
+                        adjust_learning_rate_poly(args.lr,optimizer, epoch * len(dataloaders['train']) + i,
+                                                  args.num_epochs * len(dataloaders['train']))
+                        lr = optimizer.param_groups[0]['lr']
+                        writer.add_scalar('lr', lr, global_step=epoch * len(dataloaders['train']) + i)
+                epoch_samples += 1
+            epoch_loss = metrics['loss'] / epoch_samples
+            ce_loss = metrics['ce_loss'] / epoch_samples
+            ls_loss = metrics['ls_loss'] / epoch_samples
+            miou = evaluator.Mean_Intersection_over_Union()
+            AA = evaluator.Pixel_Accuracy_Class()
+            OA = evaluator.Pixel_Accuracy()
+            confusion_matrix = evaluator.confusion_matrix
+            if phase == 'val':
+                miou_mat = evaluator.Mean_Intersection_over_Union_test()
+                writer.add_scalar('val/val_loss', epoch_loss, global_step=epoch)
+                writer.add_scalar('val/ce_loss', ce_loss, global_step=epoch)
+                writer.add_scalar('val/ls_loss', ls_loss, global_step=epoch)
+                # writer.add_scalar('val/val_fwiou', fwiou, global_step=epoch)
+                writer.add_scalar('val/val_miou', miou, global_step=epoch)
+                for index in range(args.n_cls):
+                    writer.add_scalar('class/{}'.format(index+1), miou_mat[index], global_step=epoch)
+
+                print(
+                    '[val]------miou: {:4f}, OA: {:4f}, AA: {:4f}, loss: {:4f}'.format(miou, OA, AA,
+                                                                                       epoch_loss))
+                print(
+                    '[val]------miou: {:4f}, OA: {:4f}, AA: {:4f}, loss: {:4f}'.format(miou, OA, AA,
+                                                                                       epoch_loss),
+                    file=F_txt,flush=True)
+            if phase == 'train':
+                writer.add_scalar('train/train_loss', epoch_loss, global_step=epoch)
+                writer.add_scalar('train/ce_loss', ce_loss, global_step=epoch)
+                writer.add_scalar('train/ls_loss', ls_loss, global_step=epoch)
+                # writer.add_scalar('train/train_fwiou', fwiou, global_step=epoch)
+                writer.add_scalar('train/train_miou', miou, global_step=epoch)
+                print(
+                    '[train]------miou: {:4f}, OA: {:4f}, AA: {:4f}, loss: {:4f}'.format(miou, OA,
+                                                                                         AA, epoch_loss))
+                print(
+                    '[train]------miou: {:4f}, OA: {:4f}, AA: {:4f}, loss: {:4f}'.format(miou, OA,
+                                                                                         AA, epoch_loss),
+                    file=F_txt,flush=True)
+
+            if phase == 'val' and miou > best_miou:
+                print("\33[91msaving best model miou\33[0m")
+                print("saving best model 
miou", file=F_txt,flush=True) + best_miou = miou + best_OA = OA + best_AA = AA + best_loss = epoch_loss + torch.save({ + 'name': 'resnest50_lovasz_edge_rotate', + 'epoch': epoch + 1, + 'state_dict': ema_model.state_dict(), + 'best_miou': best_miou + }, args.save_path + folder_path+'/model_best.pth') + torch.save({ + 'optimizer': optimizer.state_dict(), + }, args.save_path + folder_path+'/optimizer.pth') + time_elapsed = time.time() - since + print('{:.0f}m {:.0f}s'.format(time_elapsed//60, time_elapsed%60)) + print('{:.0f}m {:.0f}s'.format(time_elapsed//60, time_elapsed%60),file=F_txt,flush=True) + + print('[Best val]------miou: {:4f}; OA: {:4f}; AA: {:4f}; loss: {:4f}'.format(best_miou, + best_OA, best_AA, + best_loss)) + print('[Best val]------miou: {:4f}; OA: {:4f}; AA: {:4f}; loss: {:4f}'.format(best_miou, + best_OA, best_AA, + best_loss),file=F_txt,flush=True) + F_txt.close() +if __name__ == '__main__': + train_model() diff --git a/utils/__pycache__/ema.cpython-36.pyc b/utils/__pycache__/ema.cpython-36.pyc new file mode 100644 index 0000000..5df4054 Binary files /dev/null and b/utils/__pycache__/ema.cpython-36.pyc differ diff --git a/utils/__pycache__/label2color.cpython-35.pyc b/utils/__pycache__/label2color.cpython-35.pyc new file mode 100644 index 0000000..f7bbd27 Binary files /dev/null and b/utils/__pycache__/label2color.cpython-35.pyc differ diff --git a/utils/__pycache__/label2color.cpython-36.pyc b/utils/__pycache__/label2color.cpython-36.pyc new file mode 100644 index 0000000..ae85396 Binary files /dev/null and b/utils/__pycache__/label2color.cpython-36.pyc differ diff --git a/utils/__pycache__/lr_scheduler.cpython-35.pyc b/utils/__pycache__/lr_scheduler.cpython-35.pyc new file mode 100644 index 0000000..389e610 Binary files /dev/null and b/utils/__pycache__/lr_scheduler.cpython-35.pyc differ diff --git a/utils/__pycache__/lr_scheduler.cpython-36.pyc b/utils/__pycache__/lr_scheduler.cpython-36.pyc new file mode 100644 index 0000000..4600023 Binary files /dev/null and b/utils/__pycache__/lr_scheduler.cpython-36.pyc differ diff --git a/utils/ema.py b/utils/ema.py new file mode 100644 index 0000000..ef63259 --- /dev/null +++ b/utils/ema.py @@ -0,0 +1,20 @@ +import torch + +class WeightEMA(object): + def __init__(self, model, ema_model, alpha=0.999): + self.model = model + self.ema_model = ema_model + self.alpha = alpha + self.params = list(model.state_dict().values()) + self.ema_params = list(ema_model.state_dict().values()) + + for param, ema_param in zip(self.params, self.ema_params): + param.data.copy_(ema_param.data) + + def step(self): + one_minus_alpha = 1.0 - self.alpha + for param, ema_param in zip(self.params, self.ema_params): + if ema_param.dtype==torch.float32: + ema_param.mul_(self.alpha) + ema_param.add_(param * one_minus_alpha) + # customized weight decay \ No newline at end of file diff --git a/utils/label2color.py b/utils/label2color.py new file mode 100644 index 0000000..977d8b8 --- /dev/null +++ b/utils/label2color.py @@ -0,0 +1,24 @@ +import numpy as np + +def label_img_to_color(img): #bgr + label_to_color = { + 0: [0,0,0], + 1: [255, 255,255] + } + img_height, img_width = img.shape + img_color = np.zeros((img_height, img_width, 3),dtype=np.uint8) + for cls in range(2): + img_color[img==cls] = np.array(label_to_color[cls]) + return img_color + +def diff_label_img_to_color(img): #bgr + label_to_color = { + 255:[128,128,128], + 0: [0,0,0], + 1: [255, 255,255] + } + img_height, img_width = img.shape + img_color = np.zeros((img_height, img_width, 
3),dtype=np.uint8) + for cls in [0,1,255]: + img_color[img==cls] = np.array(label_to_color[cls]) + return img_color \ No newline at end of file diff --git a/utils/lr_scheduler.py b/utils/lr_scheduler.py new file mode 100644 index 0000000..158184b --- /dev/null +++ b/utils/lr_scheduler.py @@ -0,0 +1,36 @@ +from torch.optim.lr_scheduler import _LRScheduler + +class GradualWarmupScheduler(_LRScheduler): + def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): + self.multiplier = multiplier + self.total_epoch = total_epoch + self.after_scheduler = after_scheduler + self.finished = False + super().__init__(optimizer) + def get_lr(self): + if self.last_epoch > self.total_epoch: + if self.after_scheduler: + if not self.finished: + self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs] + self.finished = True + return self.after_scheduler.get_lr() + return [base_lr * self.multiplier for base_lr in self.base_lrs] + return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs] + def step(self, epoch=None, metrics=None): + if self.finished and self.after_scheduler: + if epoch is None: + self.after_scheduler.step(None) + else: + self.after_scheduler.step(epoch - self.total_epoch) + else: + return super(GradualWarmupScheduler, self).step(epoch) + +def lr_poly(base_lr,i_iter,max_iter,power): + return base_lr*((1-float(i_iter)/max_iter)**(power)) + +def adjust_learning_rate_poly(init_lr,optimizer,i_iter,max_iter): + lr = lr_poly(init_lr,i_iter,max_iter,0.9) + for param_group in optimizer.param_groups: + param_group['lr'] = lr + return lr +
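
Below is a minimal usage sketch (not part of the repository) of how the poly learning-rate decay defined in utils/lr_scheduler.py behaves when driven the way train.py drives it. The values lr=0.01 and num_epochs=150 come from run.sh; the stand-in model and the iters_per_epoch value are assumptions made only for illustration.

import torch
from utils.lr_scheduler import adjust_learning_rate_poly

# Stand-in for the segmentation model; only its parameters matter here (assumed, for illustration).
model = torch.nn.Conv2d(3, 2, kernel_size=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=0.0001)

base_lr = 0.01                      # --lr in run.sh
iters_per_epoch = 100               # assumed; train.py uses len(dataloaders['train'])
max_iter = 150 * iters_per_epoch    # --num_epochs=150

for it in (0, max_iter // 4, max_iter // 2, max_iter - 1):
    # lr = base_lr * (1 - it / max_iter) ** 0.9, written back into every param group
    lr = adjust_learning_rate_poly(base_lr, optimizer, it, max_iter)
    print(it, lr, optimizer.param_groups[0]['lr'])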