diff --git a/backbone/incsar.py b/backbone/incsar.py
new file mode 100644
index 0000000..278e2b6
--- /dev/null
+++ b/backbone/incsar.py
@@ -0,0 +1,27 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
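+
+# Lightweight CNN branch for SAR target chips: four conv blocks (16/32/64/128
+# channels) interleaved with 2x2 max-pooling, followed by dropout and a flatten
+# layer. With the 70x70 inputs produced by build_transform_incsar in utils/data.py
+# this flattens to 128*3*3 = 1152 features (the out_dim set in utils/inc_net.py).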
+class sar_cnn(nn.Module):
+    def __init__(self, in_features=3, out_features=10):
+        super(sar_cnn, self).__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.conv16 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=7)
+        self.conv32 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)
+        self.conv64 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
+        self.conv128 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
+        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
+        self.dropout = nn.Dropout(p=0.4)
+        self.flatten = nn.Flatten()
+
+    def forward(self, x):
+        x = F.relu(self.conv16(x))
+        x = self.pool(x)
+        x = F.relu(self.conv32(x))
+        x = self.pool(x)
+        x = F.relu(self.conv64(x))
+        x = self.pool(x)
+        x = F.relu(self.conv128(x))
+        x = self.dropout(x)
+        x = self.flatten(x)
+        return x
\ No newline at end of file
diff --git a/exps/incsar.json b/exps/incsar.json
new file mode 100644
index 0000000..a1e3139
--- /dev/null
+++ b/exps/incsar.json
@@ -0,0 +1,24 @@
+{
+    "prefix": "reproduce",
+    "dataset": "mstar",
+    "shuffle": false,
+    "init_cls": 4,
+    "increment": 1,
+    "memory_size": 0,
+    "model_name": "incsar",
+    "backbone_type_vit": "pretrained_vit_b16_224_ssf",
+    "tuned_epoch_vit": 10,
+    "backbone_type_cnn": "sar_cnn",
+    "tuned_epoch_cnn": 30,
+    "device": ["0"],
+    "seed": [1993],
+    "eq_prot": false,
+    "init_lr": 0.01,
+    "batch_size": 48,
+    "weight_decay": 0.0005,
+    "min_lr": 0,
+    "optimizer": "sgd",
+    "lda": true,
+    "M": 10000,
+    "use_RP": true
+}
\ No newline at end of file
diff --git a/models/base.py b/models/base.py
index bd25ae3..e350c2e 100644
--- a/models/base.py
+++ b/models/base.py
@@ -444,3 +444,17 @@ def _compute_class_mean(self, data_manager, check_diff=False, oracle=False):
             self._class_means[class_idx, :] = class_mean
             self._class_covs[class_idx, ...] = class_cov
+
+    def _eval_get_logits(self):
+        self._network.eval()
+        outputs, y_true = [], []
+        for _, (_, inputs, targets) in enumerate(self.test_loader):
+            inputs = inputs.to(self._device)
+            with torch.no_grad():
+                output = self._network(inputs)["logits"]
+            outputs.append(output)
+            y_true.append(targets.cpu().numpy())
+        outputs_total = torch.cat(outputs, dim=0)
+        y_true_total = np.concatenate(y_true, axis=0)
+
+        return outputs_total, y_true_total
diff --git a/models/incsar.py b/models/incsar.py
new file mode 100644
index 0000000..080fde3
--- /dev/null
+++ b/models/incsar.py
@@ -0,0 +1,174 @@
+import logging
+import numpy as np
+import torch
+from torch import nn
+from tqdm import tqdm
+from torch import optim
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from utils.inc_net import SimpleVitNet
+from models.base import BaseLearner
+from utils.toolkit import target2onehot, tensor2numpy
+
+num_workers = 8
+
+class Learner(BaseLearner):
+    def __init__(self, args):
+        super().__init__(args)
+        self._network = SimpleVitNet(args, True)
+        self.batch_size = args["batch_size"]
+        self.init_lr = args["init_lr"]
+        self.weight_decay = args["weight_decay"] if args["weight_decay"] is not None else 0.0005
+        self.min_lr = args['min_lr'] if args['min_lr'] is not None else 1e-8
+        self.args = args
+        self.current_class = 0
+
+    def after_task(self):
+        self._known_classes = self._total_classes
+
+    def replace_fc(self, trainloader, model, args):
+        model = model.eval()
+        embedding_list = []
+        label_list = []
+        with torch.no_grad():
+            for i, batch in enumerate(trainloader):
+                (_, data, label) = batch
+                data = data.to(self._device)
+                label = label.to(self._device)
+                embedding = model.extract_vector(data)
+                embedding_list.append(embedding.cpu())
+                label_list.append(label.cpu())
+        embedding_list = torch.cat(embedding_list, dim=0)
+        label_list = torch.cat(label_list, dim=0)
+
+        Y = target2onehot(label_list, self.args["nb_classes"])
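+        # Accumulate the closed-form classifier statistics: Q sums feature-by-
+        # one-hot-label products and G the feature Gram matrix. With use_RP the
+        # embeddings are first lifted to an M-dimensional space by the frozen
+        # random projection W_rand followed by a ReLU; with eq_prot each class's
+        # contribution to Q is re-weighted by its inverse class frequency.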
+        if self.args["use_RP"] == True:
+            Features_h = F.relu(embedding_list @ self.W_rand.cpu())
+        else:
+            Features_h = embedding_list
+        # Equalization of Prototypes
+        if self.args["eq_prot"] == True:
+            class_counts = torch.bincount(label_list)
+            inv_class_frequencies = 1.0 / class_counts
+            for cls in range(self.current_class, len(class_counts)):
+                cls_mask = (label_list == cls)
+                Features_h_cls = Features_h[cls_mask]
+                Y_cls = Y[cls_mask]
+                weight = inv_class_frequencies[cls]
+                self.Q[:, cls] += weight * (Features_h_cls.T @ Y_cls[:, cls])
+                self.current_class += 1
+        else:
+            self.Q = self.Q + Features_h.T @ Y
+        if self.args["lda"] == True:
+            self.G = self.G + Features_h.T @ Features_h
+            logging.info("Calculating ridge parameter")
+            ridge = self.optimise_ridge_parameter(Features_h, Y)
+            logging.info(f"ridge = {ridge}")
+            # solving the regularised system is more numerically stable than an explicit inverse
+            Wo = torch.linalg.solve(self.G + ridge * torch.eye(self.G.size(dim=0)), self.Q).T
+        else:
+            Wo = self.Q.T
+        self._network.fc.weight.data = Wo[0:self._network.fc.weight.shape[0], :].to(self._device)
+        return model
+
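+    # One-off state for the random-projection head (first task only): a frozen
+    # Gaussian projection W_rand of shape (feature_dim x M), a frozen fc weight
+    # of matching width, and the Q / G accumulators consumed by replace_fc.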
+    def setup_RP(self):
+        if self.args["use_RP"] == True:
+            M = self.args['M']
+            self._network.RP_dim = M
+            self.W_rand = torch.randn(self._network.fc.in_features, M).to(self._device)
+            self._network.W_rand = self.W_rand
+            self._network.fc.weight = nn.Parameter(torch.Tensor(self._network.fc.out_features, M).to(self._device)).requires_grad_(False)  # num classes in task x M
+            self.Q = torch.zeros(M, self.args["nb_classes"])
+            self.G = torch.zeros(M, M)
+        else:
+            self._network.fc.weight = nn.Parameter(torch.Tensor(self._network.fc.out_features, self._network.feature_dim).to(self._device)).requires_grad_(False)  # num classes in task x feature_dim
+            self.Q = torch.zeros(self._network.feature_dim, self.args["nb_classes"])
+            self.G = torch.zeros(self._network.feature_dim, self._network.feature_dim)
+
+    def optimise_ridge_parameter(self, Features, Y):
+        ridges = 10.0 ** np.arange(0, 8)
+        num_val_samples = int(Features.shape[0] * 0.8)
+        losses = []
+        Q_val = Features[0:num_val_samples, :].T @ Y[0:num_val_samples, :]
+        G_val = Features[0:num_val_samples, :].T @ Features[0:num_val_samples, :]
+        for ridge in ridges:
+            Wo = torch.linalg.solve(G_val + ridge * torch.eye(G_val.size(dim=0)), Q_val).T  # better numerical stability than an explicit inverse
+            Y_train_pred = Features[num_val_samples:, :] @ Wo.T
+            losses.append(F.mse_loss(Y_train_pred, Y[num_val_samples:, :]))
+        ridge = ridges[np.argmin(np.array(losses))]
+        return ridge
+
+    def incremental_train(self, data_manager):
+        self._cur_task += 1
+        self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
+        self._network.update_fc(self._total_classes)
+        logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))
+
+        train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source="train", mode="train")
+        self.train_dataset = train_dataset
+        self.data_manager = data_manager
+        self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=num_workers)
+
+        test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test")
+        self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)
+
+        train_dataset_for_protonet = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source="train", mode="test")
+        self.train_loader_for_protonet = DataLoader(train_dataset_for_protonet, batch_size=self.batch_size, shuffle=True, num_workers=num_workers)
+
+        if len(self._multiple_gpus) > 1:
+            print('Multiple GPUs')
+            self._network = nn.DataParallel(self._network, self._multiple_gpus)
+        self._train(self.train_loader, self.test_loader, self.train_loader_for_protonet)
+
+        if len(self._multiple_gpus) > 1:
+            self._network = self._network.module
+
+    def _train(self, train_loader, test_loader, train_loader_for_protonet):
+        self._network.to(self._device)
+        if self._cur_task == 0:
+            print("Finetune in Base Task:")
+            # total_params = sum(p.numel() for p in self._network.parameters())
+            # print(f'{total_params:,} total parameters.')
+            optimizer = optim.SGD(self._network.parameters(), momentum=0.9, lr=self.init_lr, weight_decay=self.weight_decay)
+            scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.args['tuned_epoch'], eta_min=self.min_lr)
+            self._init_train(train_loader, test_loader, optimizer, scheduler)
+        else:
+            pass
+        if self._cur_task == 0:
+            self.setup_RP()
+        self.replace_fc(train_loader_for_protonet, self._network, None)
+
+    def _init_train(self, train_loader, test_loader, optimizer, scheduler):
+        prog_bar = tqdm(range(self.args['tuned_epoch']))
+        for _, epoch in enumerate(prog_bar):
+            self._network.train()
+            losses = 0.0
+            correct, total = 0, 0
+            for i, (_, inputs, targets) in enumerate(train_loader):
+                inputs, targets = inputs.to(self._device), targets.to(self._device)
+                logits = self._network(inputs)["logits"]
+
+                loss = F.cross_entropy(logits, targets)
+                optimizer.zero_grad()
+                loss.backward()
+                optimizer.step()
+                losses += loss.item()
+
+                _, preds = torch.max(logits, dim=1)
+                correct += preds.eq(targets.expand_as(preds)).cpu().sum()
+                total += len(targets)
+
+            scheduler.step()
+            train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)
+
+            test_acc = self._compute_accuracy(self._network, test_loader)
+            info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
+                self._cur_task,
+                epoch + 1,
+                self.args['tuned_epoch'],
+                losses / len(train_loader),
+                train_acc,
+                test_acc,
+            )
+            prog_bar.set_description(info)
+
+        logging.info(info)
\ No newline at end of file
diff --git a/trainer.py b/trainer.py
index dccf0ee..097d940 100644
--- a/trainer.py
+++ b/trainer.py
@@ -16,7 +16,10 @@ def train(args):
     for seed in seed_list:
         args["seed"] = seed
         args["device"] = device
-        _train(args)
+        if args["model_name"] != 'incsar':
+            _train(args)
+        else:
+            _train_incsar(args)
 
 
 def _train(args):
@@ -141,6 +144,113 @@ def _train(args):
         print(np_acctable)
         logging.info('Forgetting (NME): {}'.format(forgetting))
 
+def _train_incsar(args):
+    from utils.toolkit import accuracy
+
+    init_cls = 0 if args["init_cls"] == args["increment"] else args["init_cls"]
+    logs_name = "logs/{}/{}/{}/{}".format(args["model_name"], args["dataset"], init_cls, args['increment'])
+
+    if not os.path.exists(logs_name):
+        os.makedirs(logs_name)
+
+    logfilename = "logs/{}/{}/{}/{}/{}_{}_{}".format(
+        args["model_name"],
+        args["dataset"],
+        init_cls,
+        args["increment"],
+        args["prefix"],
+        args["seed"],
+        args["backbone_type_vit"],
+    )
+    logging.basicConfig(
+        level=logging.INFO,
+        format="%(asctime)s [%(filename)s] => %(message)s",
+        handlers=[
+            logging.FileHandler(filename=logfilename + ".log"),
+            logging.StreamHandler(sys.stdout),
+        ],
+    )
+
+    _set_random(args["seed"])
+    _set_device(args)
+    print_args(args)
+
+    args_vit = copy.deepcopy(args)
+    args_vit["backbone_type"] = args["backbone_type_vit"]
+    args_vit["tuned_epoch"] = args["tuned_epoch_vit"]
+
+    data_manager_vit = DataManager(
+        args["dataset"],
+        args["shuffle"],
+        args["seed"],
+        args["init_cls"],
+        args["increment"],
+        args_vit,
+    )
+
+    args_vit["nb_classes"] = data_manager_vit.nb_classes  # update args
+    args_vit["nb_tasks"] = data_manager_vit.nb_tasks
+    model_vit = factory.get_model(args_vit["model_name"], args_vit)
+
+    args_cnn = copy.deepcopy(args)
+    args_cnn["backbone_type"] = args["backbone_type_cnn"]
+    args_cnn["tuned_epoch"] = args["tuned_epoch_cnn"]
+
+    data_manager_cnn = DataManager(
+        args["dataset"],
+        args["shuffle"],
+        args["seed"],
+        args["init_cls"],
+        args["increment"],
+        args_cnn,
+    )
+
+    args_cnn["nb_classes"] = data_manager_cnn.nb_classes  # update args
+    args_cnn["nb_tasks"] = data_manager_cnn.nb_tasks
+    model_cnn = factory.get_model(args_cnn["model_name"], args_cnn)
+
+    cnn_curve, nme_curve = {"top1": [], "top5": []}, {"top1": [], "top5": []}
+    cnn_matrix, nme_matrix = [], []
+
+    for task in range(data_manager_vit.nb_tasks):
+        logging.info("ViT params: {}".format(count_parameters(model_vit._network)))
+        _set_random(args["seed"])
+        model_vit.incremental_train(data_manager_vit)
+
+        logits_vit, y_true = model_vit._eval_get_logits()
+        model_vit.after_task()
+
+        logging.info("CNN params: {}".format(count_parameters(model_cnn._network)))
+        _set_random(args["seed"])
+        model_cnn.incremental_train(data_manager_cnn)
+        logits_cnn, _ = model_cnn._eval_get_logits()
+        model_cnn.after_task()
+
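+        # Late fusion of the two branches: average the softmax probabilities of
+        # the ViT and CNN learners, then take the top-1 class as the prediction.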
+        logits = (torch.nn.functional.softmax(logits_vit, dim=1) + torch.nn.functional.softmax(logits_cnn, dim=1)) / 2
+        predicts = torch.topk(logits, k=1, dim=1, largest=True, sorted=True)[1]  # [bs, topk]
+        predicts = predicts.cpu().numpy()
+        cnn_accy = {}
+        print(predicts.shape)
+        print(len(y_true))
+        grouped = accuracy(predicts.T[0], y_true, logits_vit.size()[1] - 1, init_cls, args['increment'])
+        cnn_accy["grouped"] = grouped
+        cnn_accy["top1"] = grouped["total"]
+        cnn_accy["top{}".format(1)] = np.around(
+            (predicts.T == np.tile(y_true, (1, 1))).sum() * 100 / len(y_true),
+            decimals=2,
+        )
+        logging.info("Top-1 Accuracy per Task: {}".format(cnn_accy["grouped"]))
+
+        cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key]
+        cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys]
+        cnn_matrix.append(cnn_values)
+
+        cnn_curve["top1"].append(cnn_accy["top1"])
+
+        logging.info("Top-1 Accuracy curve: {}".format(cnn_curve["top1"]))
+        logging.info("Average Accuracy: {:.2f} \n".format(sum(cnn_curve["top1"]) / len(cnn_curve["top1"])))
+
+    return cnn_curve["top1"], sum(cnn_curve["top1"]) / len(cnn_curve["top1"])
 
 
 def _set_device(args):
     device_type = args["device"]
diff --git a/utils/data.py b/utils/data.py
index 89c1dc1..c966174 100644
--- a/utils/data.py
+++ b/utils/data.py
@@ -92,6 +92,36 @@ def build_transform_coda_prompt(is_train, args):
     return t
 
+
+def build_transform_incsar(is_train, args):
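+    # Two input pipelines: the sar_cnn branch sees 70x70 chips (a 32x32 center
+    # crop randomly resized to 70 at train time, a 70x70 center crop at test
+    # time), while the ViT branch takes a 64x64 center crop scaled to 224x224.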
+    if args['backbone_type'] == 'sar_cnn':
+        input_size = 70
+        crop_size = 32
+    else:
+        input_size = 224
+        crop_size = 64
+
+    if is_train:
+        scale = (0.8, 1.0)
+        transform = [
+            transforms.CenterCrop(crop_size),
+            transforms.RandomResizedCrop(input_size, scale=scale),
+            transforms.RandomHorizontalFlip(p=0.5),
+            transforms.ToTensor(),
+        ]
+    else:
+        if args['backbone_type'] == 'sar_cnn':
+            transform = [
+                transforms.CenterCrop(input_size),
+                transforms.ToTensor()
+            ]
+        else:
+            transform = [
+                transforms.CenterCrop(crop_size),
+                transforms.Resize(input_size),
+                transforms.ToTensor()
+            ]
+    return transform
+
 
 def build_transform(is_train, args):
     input_size = 224
     resize_im = input_size > 32
@@ -342,4 +372,24 @@ def download_data(self):
         print(test_dset.class_to_idx)
 
         self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
+        self.test_data, self.test_targets = split_images_labels(test_dset.imgs)
+
+class mstar(iData):
+
+    def __init__(self, args):
+        super().__init__()
+        self.use_path = True
+        self.common_trsf = []
+
+        self.train_trsf = build_transform_incsar(True, args)
+        self.test_trsf = build_transform_incsar(False, args)
+        self.class_order = np.arange(10).tolist()
+
+    def download_data(self):
+        train_dir = "./datasets/MSTAR/train"
+        test_dir = "./datasets/MSTAR/test"
+
+        train_dset = datasets.ImageFolder(train_dir)
+        test_dset = datasets.ImageFolder(test_dir)
+
+        self.train_data, self.train_targets = split_images_labels(train_dset.imgs)
+        self.test_data, self.test_targets = split_images_labels(test_dset.imgs)
\ No newline at end of file
diff --git a/utils/data_manager.py b/utils/data_manager.py
index 828c91d..d0b3c28 100644
--- a/utils/data_manager.py
+++ b/utils/data_manager.py
@@ -3,7 +3,7 @@
 from PIL import Image
 from torch.utils.data import Dataset
 from torchvision import transforms
-from utils.data import iCIFAR10, iCIFAR100, iImageNet100, iImageNet1000, iCIFAR224, iImageNetR,iImageNetA,CUB, objectnet, omnibenchmark, vtab
+from utils.data import iCIFAR10, iCIFAR100, iImageNet100, iImageNet1000, iCIFAR224, iImageNetR,iImageNetA,CUB, objectnet, omnibenchmark, vtab, mstar
 
 
 class DataManager(object):
@@ -237,6 +237,8 @@ def _get_idata(dataset_name, args=None):
         return omnibenchmark()
     elif name == "vtab":
         return vtab()
+    elif name == 'mstar':
+        return mstar(args)
     else:
         raise NotImplementedError("Unknown dataset {}.".format(dataset_name))
diff --git a/utils/factory.py b/utils/factory.py
index fd501ff..e6ed793 100644
--- a/utils/factory.py
+++ b/utils/factory.py
@@ -46,6 +46,8 @@ def get_model(model_name, args):
         from models.cofima import Learner
     elif name == 'duct':
         from models.duct import Learner
+    elif name == 'incsar':
+        from models.incsar import Learner
     else:
         assert 0
     return Learner(args)
\ No newline at end of file
diff --git a/utils/inc_net.py b/utils/inc_net.py
index 5120708..4e7ae95 100644
--- a/utils/inc_net.py
+++ b/utils/inc_net.py
@@ -27,7 +27,7 @@ def get_backbone(args, pretrained=False):
         return _basenet, _adaptive_net
     # SSF
     elif '_ssf' in name:
-        if args["model_name"] == "aper_ssf" or args["model_name"] == "ranpac" or args["model_name"] == "fecam":
+        if args["model_name"] == "aper_ssf" or args["model_name"] == "ranpac" or args["model_name"] == "fecam" or args["model_name"] == 'incsar':
             from backbone import vit_ssf
             if name == "pretrained_vit_b16_224_ssf":
                 model = timm.create_model("vit_base_patch16_224_ssf", pretrained=True, num_classes=0)
@@ -231,6 +231,11 @@ def get_backbone(args, pretrained=False):
         else:
             raise NotImplementedError("Unknown type {}".format(name))
         return model
+    elif 'sar_cnn' in name:
+        from backbone.incsar import sar_cnn
+        model = sar_cnn()
+        model.out_dim = 1152
+        return model.eval()
     else:
         raise NotImplementedError("Unknown type {}".format(name))