27 changes: 27 additions & 0 deletions backbone/incsar.py
@@ -0,0 +1,27 @@
import torch
from torch import nn
from torch.nn import functional as F
class sar_cnn(nn.Module):
    """Lightweight CNN feature extractor for SAR imagery."""

    def __init__(self, in_features=3, out_features=10):
        super(sar_cnn, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.conv16 = nn.Conv2d(in_channels=in_features, out_channels=16, kernel_size=7)
        self.conv32 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5)
        self.conv64 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5)
        self.conv128 = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.4)
        self.flatten = nn.Flatten()

    def forward(self, x):
        x = F.relu(self.conv16(x))
        x = self.pool(x)
        x = F.relu(self.conv32(x))
        x = self.pool(x)
        x = F.relu(self.conv64(x))
        x = self.pool(x)
        x = F.relu(self.conv128(x))
        x = self.dropout(x)
        x = self.flatten(x)
        return x
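For reference, a minimal shape-check sketch for the new backbone (not part of the diff), assuming 3-channel 128x128 crops; the flattened feature dimension depends on the input resolution, so the downstream classifier head must be sized to match:

import torch
from backbone.incsar import sar_cnn

# With 128x128 inputs the last feature map is 128 x 10 x 10, so flatten()
# yields 12800-dimensional vectors per image.
model = sar_cnn()
x = torch.randn(2, 3, 128, 128)
features = model(x)
print(features.shape)  # torch.Size([2, 12800])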
24 changes: 24 additions & 0 deletions exps/incsar.json
@@ -0,0 +1,24 @@
{
"prefix": "reproduce",
"dataset": "mstar",
"shuffle": false,
"init_cls": 4,
"increment": 1,
"memory_size": 0,
"model_name": "incsar",
"backbone_type_vit": "pretrained_vit_b16_224_ssf",
"tuned_epoch_vit": 10,
"backbone_type_cnn": "sar_cnn",
"tuned_epoch_cnn": 30,
"device": ["0"],
"seed": [1993],
"eq_prot": false,
"init_lr": 0.01,
"batch_size": 48,
"weight_decay": 0.0005,
"min_lr": 0,
"optimizer": "sgd",
"lda": true,
"M": 10000,
"use_RP":true
}
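For reference, this config is plain JSON handed to the trainer; a minimal sketch of how it might be consumed (the exact entry-point wiring is an assumption, not part of this diff):

import json
from trainer import train

# Hypothetical launcher: load the experiment config and start incremental training.
with open("exps/incsar.json") as f:
    args = json.load(f)
train(args)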
14 changes: 14 additions & 0 deletions models/base.py
@@ -444,3 +444,17 @@ def _compute_class_mean(self, data_manager, check_diff=False, oracle=False):

self._class_means[class_idx, :] = class_mean
self._class_covs[class_idx, ...] = class_cov

def _eval_get_logits(self):
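        # Collect raw logits over the full test loader; the trainer ensembles these
        # with the logits of a second learner via softmax averaging.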
self._network.eval()
        outputs, y_true = [], []
for _, (_, inputs, targets) in enumerate(self.test_loader):
inputs = inputs.to(self._device)
with torch.no_grad():
output = self._network(inputs)["logits"]
outputs.append(output)
y_true.append(targets.cpu().numpy())
outputs_total = torch.cat(outputs, dim=0)
y_true_total = np.concatenate(y_true, axis=0)

return outputs_total, y_true_total
174 changes: 174 additions & 0 deletions models/incsar.py
@@ -0,0 +1,174 @@
import logging
import numpy as np
import torch
from torch import nn
from tqdm import tqdm
from torch import optim
from torch.nn import functional as F
from torch.utils.data import DataLoader
from utils.inc_net import SimpleVitNet
from models.base import BaseLearner
from utils.toolkit import target2onehot, tensor2numpy

num_workers = 8

class Learner(BaseLearner):
def __init__(self, args):
super().__init__(args)
self._network = SimpleVitNet(args, True)
self.batch_size = args["batch_size"]
self.init_lr = args["init_lr"]
self.weight_decay = args["weight_decay"] if args["weight_decay"] is not None else 0.0005
self.min_lr = args['min_lr'] if args['min_lr'] is not None else 1e-8
self.args = args
        self.current_class = 0

def after_task(self):
self._known_classes = self._total_classes

def replace_fc(self, trainloader, model, args):
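        # Closed-form head update: extract features with the frozen backbone, optionally
        # random-project them, and accumulate the least-squares statistics Q and G.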
model = model.eval()
embedding_list = []
label_list = []
with torch.no_grad():
for i, batch in enumerate(trainloader):
                (_, data, label) = batch
data = data.to(self._device)
label = label.to(self._device)
embedding = model.extract_vector(data)
embedding_list.append(embedding.cpu())
label_list.append(label.cpu())
embedding_list = torch.cat(embedding_list, dim=0)
label_list = torch.cat(label_list, dim=0)

Y = target2onehot(label_list, self.args["nb_classes"])
if self.args["use_RP"] == True:
Features_h = F.relu(embedding_list @ self.W_rand.cpu())
else:
Features_h = embedding_list
# Equalization of Prototypes
if self.args["eq_prot"] == True:
class_counts = torch.bincount(label_list)
inv_class_frequencies = 1.0 / class_counts
            for cls in range(self.current_class, len(class_counts)):
cls_mask = (label_list == cls)
Features_h_cls = Features_h[cls_mask]
Y_cls = Y[cls_mask]
weight = inv_class_frequencies[cls]
self.Q[:, cls] += weight * (Features_h_cls.T @ Y_cls[:, cls])
self.current_class += 1
else:
self.Q = self.Q + Features_h.T @ Y
if self.args["lda"] == True:
self.G = self.G + Features_h.T @ Features_h
else:
Wo = self.Q.T
logging.info("Calculating ridge parameter")
ridge = self.optimise_ridge_parameter(Features_h, Y)
logging.info(f"ridge = {ridge}")

        Wo = torch.linalg.solve(self.G + ridge*torch.eye(self.G.size(dim=0)), self.Q).T  # better numerical stability than .inv()
self._network.fc.weight.data = Wo[0:self._network.fc.weight.shape[0],:].to(self._device)
return model

def setup_RP(self):
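        # Initialise the frozen classifier head. With use_RP, W_rand projects backbone
        # features into an M-dimensional space; Q and G accumulate the regression statistics.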
if self.args["use_RP"] == True :
M = self.args['M']
self._network.RP_dim = M
self.W_rand = torch.randn(self._network.fc.in_features, M).to(self._device)
self._network.W_rand = self.W_rand
self._network.fc.weight = nn.Parameter(torch.Tensor(self._network.fc.out_features, M).to(self._device)).requires_grad_(False) # num classes in task x M
self.Q = torch.zeros(M, self.args["nb_classes"])
self.G = torch.zeros(M, M)
else:
            self._network.fc.weight = nn.Parameter(torch.Tensor(self._network.fc.out_features, self._network.feature_dim).to(self._device)).requires_grad_(False)  # num classes in task x feature_dim
self.Q = torch.zeros(self._network.feature_dim, self.args["nb_classes"])
self.G = torch.zeros(self._network.feature_dim, self._network.feature_dim)

def optimise_ridge_parameter(self, Features, Y):
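        # Grid-search the ridge regulariser: fit Q/G on the first 80% of the samples and
        # keep the value with the lowest MSE on the remaining held-out portion.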
ridges = 10.0 ** np.arange(0, 8)
num_val_samples = int(Features.shape[0] * 0.8)
losses = []
Q_val = Features[0:num_val_samples, :].T @ Y[0:num_val_samples, :]
G_val = Features[0:num_val_samples, :].T @ Features[0:num_val_samples, :]
for ridge in ridges:
            Wo = torch.linalg.solve(G_val + ridge*torch.eye(G_val.size(dim=0)), Q_val).T  # better numerical stability than .inv()
Y_train_pred = Features[num_val_samples::,:] @ Wo.T
losses.append(F.mse_loss(Y_train_pred, Y[num_val_samples::, :]))
ridge = ridges[np.argmin(np.array(losses))]
return ridge

def incremental_train(self, data_manager):
self._cur_task += 1
self._total_classes = self._known_classes + data_manager.get_task_size(self._cur_task)
self._network.update_fc(self._total_classes)
logging.info("Learning on {}-{}".format(self._known_classes, self._total_classes))

train_dataset = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes),source="train", mode="train")
        self.train_dataset = train_dataset
        self.data_manager = data_manager
self.train_loader = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=num_workers)

test_dataset = data_manager.get_dataset(np.arange(0, self._total_classes), source="test", mode="test")
self.test_loader = DataLoader(test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=num_workers)

train_dataset_for_protonet = data_manager.get_dataset(np.arange(self._known_classes, self._total_classes), source="train", mode="test")
self.train_loader_for_protonet = DataLoader(train_dataset_for_protonet, batch_size=self.batch_size, shuffle=True, num_workers=num_workers)

if len(self._multiple_gpus) > 1:
print('Multiple GPUs')
self._network = nn.DataParallel(self._network, self._multiple_gpus)
self._train(self.train_loader, self.test_loader, self.train_loader_for_protonet)

if len(self._multiple_gpus) > 1:
self._network = self._network.module

def _train(self, train_loader, test_loader, train_loader_for_protonet):
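        # The backbone is fine-tuned only on the base task; replace_fc then refreshes the
        # classifier head with the closed-form update.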
self._network.to(self._device)
if self._cur_task == 0:
print("Finetune in Base Task:")
#total_params = sum(p.numel() for p in self._network.parameters())
#print(f'{total_params:,} total parameters.')
optimizer = optim.SGD(self._network.parameters(), momentum=0.9, lr=self.init_lr,weight_decay=self.weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=self.args['tuned_epoch'], eta_min=self.min_lr)
self._init_train(train_loader, test_loader, optimizer, scheduler)
else:
pass
if self._cur_task == 0:
self.setup_RP()
self.replace_fc(train_loader_for_protonet, self._network, None)

def _init_train(self, train_loader, test_loader, optimizer, scheduler):
prog_bar = tqdm(range(self.args['tuned_epoch']))
for _, epoch in enumerate(prog_bar):
self._network.train()
losses = 0.0
correct, total = 0, 0
for i, (_, inputs, targets) in enumerate(train_loader):
inputs, targets = inputs.to(self._device), targets.to(self._device)
logits = self._network(inputs)["logits"]

loss = F.cross_entropy(logits, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
losses += loss.item()

_, preds = torch.max(logits, dim=1)
correct += preds.eq(targets.expand_as(preds)).cpu().sum()
total += len(targets)

scheduler.step()
train_acc = np.around(tensor2numpy(correct) * 100 / total, decimals=2)

test_acc = self._compute_accuracy(self._network, test_loader)
info = "Task {}, Epoch {}/{} => Loss {:.3f}, Train_accy {:.2f}, Test_accy {:.2f}".format(
self._cur_task,
epoch + 1,
self.args['tuned_epoch'],
losses / len(train_loader),
train_acc,
test_acc,
)
prog_bar.set_description(info)

logging.info(info)
112 changes: 111 additions & 1 deletion trainer.py
@@ -16,7 +16,10 @@ def train(args):
for seed in seed_list:
args["seed"] = seed
args["device"] = device
_train(args)
if args["model_name"] != 'incsar':
_train(args)
else:
_train_incsar(args)


def _train(args):
@@ -141,6 +144,113 @@ def _train(args):
print(np_acctable)
logging.info('Forgetting (NME): {}'.format(forgetting))

def _train_incsar(args):
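    # Dual-branch IncSAR training: a ViT learner and a CNN learner are trained on the same
    # incremental splits, and their softmaxed logits are averaged at evaluation time.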
from utils.toolkit import accuracy

    init_cls = 0 if args["init_cls"] == args["increment"] else args["init_cls"]
    logs_name = "logs/{}/{}/{}/{}".format(args["model_name"], args["dataset"], init_cls, args["increment"])

if not os.path.exists(logs_name):
os.makedirs(logs_name)

logfilename = "logs/{}/{}/{}/{}/{}_{}_{}".format(
args["model_name"],
args["dataset"],
init_cls,
args["increment"],
args["prefix"],
args["seed"],
args["backbone_type_vit"],
)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(filename)s] => %(message)s",
handlers=[
logging.FileHandler(filename=logfilename + ".log"),
logging.StreamHandler(sys.stdout),
],
)

_set_random(args["seed"])
_set_device(args)
print_args(args)

args_vit = copy.deepcopy(args)
args_vit["backbone_type"] = args["backbone_type_vit"]
args_vit["tuned_epoch"]= args["tuned_epoch_vit"]

data_manager_vit = DataManager(
args["dataset"],
args["shuffle"],
args["seed"],
args["init_cls"],
args["increment"],
args_vit,
)

args_vit["nb_classes"] = data_manager_vit.nb_classes # update args
args_vit["nb_tasks"] = data_manager_vit.nb_tasks
model_vit = factory.get_model(args_vit["model_name"], args_vit)

args_cnn = copy.deepcopy(args)
args_cnn["backbone_type"] = args["backbone_type_cnn"]
args_cnn["tuned_epoch"]= args["tuned_epoch_cnn"]

data_manager_cnn = DataManager(
args["dataset"],
args["shuffle"],
args["seed"],
args["init_cls"],
args["increment"],
args_cnn,
)

args_cnn["nb_classes"] = data_manager_cnn.nb_classes # update args
args_cnn["nb_tasks"] = data_manager_cnn.nb_tasks
model_cnn = factory.get_model(args_cnn["model_name"], args_cnn)

cnn_curve, nme_curve = {"top1": [], "top5": []}, {"top1": [], "top5": []}
cnn_matrix, nme_matrix = [], []

for task in range(data_manager_vit.nb_tasks):
logging.info("ViT params: {}".format(count_parameters(model_vit._network)))
_set_random(args["seed"])
model_vit.incremental_train(data_manager_vit)

logits_vit,y_true = model_vit._eval_get_logits()
model_vit.after_task()

logging.info("CNN params: {}".format(count_parameters(model_cnn._network)))
_set_random(args["seed"])
model_cnn.incremental_train(data_manager_cnn)
logits_cnn, _ = model_cnn._eval_get_logits()
model_cnn.after_task()

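        # Late fusion: average the two branches' softmax distributions and take the arg-max
        # as the ensemble prediction.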
logits = (torch.nn.functional.softmax(logits_vit, dim=1) + torch.nn.functional.softmax(logits_cnn, dim=1)) / 2
predicts = torch.topk(logits, k=1, dim=1, largest=True, sorted=True)[1] # [bs, topk]
        predicts = predicts.cpu().numpy()
cnn_accy = {}
print(predicts.shape)
print(len(y_true))
grouped = accuracy(predicts.T[0], y_true, logits_vit.size()[1]-1, init_cls, args['increment'])
cnn_accy["grouped"] = grouped
cnn_accy["top1"] = grouped["total"]
cnn_accy["top{}".format(1)] = np.around(
(predicts.T == np.tile(y_true, (1, 1))).sum() * 100 / len(y_true),
decimals=2,
)
logging.info("Top-1 Accuracy per Task:: {}".format(cnn_accy["grouped"]))

cnn_keys = [key for key in cnn_accy["grouped"].keys() if '-' in key]
cnn_values = [cnn_accy["grouped"][key] for key in cnn_keys]
cnn_matrix.append(cnn_values)

cnn_curve["top1"].append(cnn_accy["top1"])

logging.info("Top-1 Accuracy curve: {}".format(cnn_curve["top1"]))
logging.info("Average Accuracy: {:.2f} \n".format(sum(cnn_curve["top1"])/len(cnn_curve["top1"])))

return cnn_curve["top1"], sum(cnn_curve["top1"])/len(cnn_curve["top1"])

def _set_device(args):
device_type = args["device"]