Commit 5e3e761

fix bugs

tensorsketch committed Jul 17, 2020
1 parent 13f70fb commit 5e3e761
Showing 180 changed files with 3,022 additions and 132 deletions.
31 changes: 13 additions & 18 deletions core/train_weight.py
@@ -19,13 +19,6 @@ def weight(ten, a=10):
 def lipton_weight(ten, beta = 4):
     order = torch.argsort(ten)
     return (order < len(ten)/(1+beta)).float()
-def normalized_weight(ten, a = 10, reverse = False, lipton = False):
-    prob = torch.exp(ten[:,1])/ (torch.exp(ten[:,1]) + torch.exp(ten[:,0]))
-    if not reverse:
-        w = weight(prob)
-    else:
-        w = weight(1-prob)
-    return w
 
 def get_quantile(ten, a = 0.5):
     return torch.kthvalue(ten,math.floor(len(ten)*a))[0]
@@ -105,17 +98,18 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
                 src_loss_domain = criterion0(src_domain_output, label_src)
             else:
                 src_loss_domain = criterion(src_domain_output, label_src)
+            prob = torch.softmax(src_domain_output.data, dim = -1)
             if params.soft:
                 if params.quantile:
-                    weight_src[idx_src] = (torch.sort(src_domain_output.data[:,0])[1]).float().detach()
+                    weight_src[idx_src] = (torch.sort(prob[:,1])[1]).float().detach()
                 else:
-                    weight_src[idx_src] = normalized_weight(src_domain_output.data).detach()
+                    weight_src[idx_src] = weight(prob[:,1]).detach()
             else:
                 if params.quantile:
-                    weight_src[idx_src] = (src_domain_output.data[:,0] < \
-                        get_quantile(src_domain_output.data[:,0],params.threshold[0])).float().detach()
+                    weight_src[idx_src] = (prob[:,0] < \
+                        get_quantile(prob[:,0],params.threshold[0])).float().detach()
                 else:
-                    weight_src[idx_src] = (src_domain_output.data[:,0] < params.threshold[0]).float().detach()
+                    weight_src[idx_src] = (prob[:,0] < params.threshold[0]).float().detach()
             src_loss_domain = torch.dot(weight_src[idx_src], src_loss_domain
                 )/ torch.sum(weight_src[idx_src])
             #train on target domain
@@ -124,18 +118,19 @@
                 tgt_loss_domain = criterion0(tgt_domain_output, label_tgt)
             else:
                 tgt_loss_domain = criterion(tgt_domain_output, label_tgt)
+            prob = torch.softmax(tgt_domain_output.data, dim = -1)
+
             if params.soft:
                 if params.quantile:
-                    weight_tgt[idx_tgt] = (torch.sort(tgt_domain_output.data[:,1])[1]).float().detach()
+                    weight_tgt[idx_tgt] = (torch.sort(prob[:,0])[1]).float().detach()
                 else:
-                    weight_tgt[idx_tgt] = normalized_weight(
-                        tgt_domain_output.data, reverse = True).detach()
+                    weight_tgt[idx_tgt] = weight(prob[:,0]).detach()
             else:
                 if params.quantile:
-                    weight_tgt[idx_tgt] = (tgt_domain_output.data[:,1] < \
-                        get_quantile(tgt_domain_output.data[:,1],params.threshold[1])).float().detach()
+                    weight_tgt[idx_tgt] = (prob[:,1] < \
+                        get_quantile(prob[:,1],params.threshold[1])).float().detach()
                 else:
-                    weight_tgt[idx_tgt] = (tgt_domain_output.data[:,1] < params.threshold[1]).float().detach()
+                    weight_tgt[idx_tgt] = (prob[:,1] < params.threshold[1]).float().detach()
             tgt_loss_domain = torch.dot(weight_tgt[idx_tgt], tgt_loss_domain
                 ) / torch.sum(weight_tgt[idx_tgt])
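What this file's changes amount to: the domain-discriminator output is now passed through a softmax before any thresholding or quantile cut, so the values in params.threshold (e.g. 0.5) act on the [0, 1] probability scale rather than on unbounded logits, and the deleted normalized_weight helper (which computed the same softmax by hand) is subsumed by calling weight on the probability column directly. A minimal sketch of the hard-threshold path under those assumptions; hard_weights, domain_output, and per_sample_loss are hypothetical names for illustration, while get_quantile mirrors the helper above:

import math
import torch

def get_quantile(ten, a=0.5):
    # empirical a-quantile: the k-th smallest entry of a 1-D tensor
    return torch.kthvalue(ten, math.floor(len(ten) * a))[0]

def hard_weights(domain_output, threshold=0.5, quantile=False):
    # Softmax first: a cutoff such as 0.5 is only meaningful on
    # probabilities, not on raw discriminator logits.
    prob = torch.softmax(domain_output.detach(), dim=-1)[:, 0]
    cut = get_quantile(prob, threshold) if quantile else threshold
    return (prob < cut).float()

# Usage: weight each sample's domain loss, then renormalize, as in the
# torch.dot(...) / torch.sum(...) lines of train_dann above.
domain_output = torch.randn(64, 2)   # stand-in logits
per_sample_loss = torch.rand(64)     # stand-in per-sample losses
w = hard_weights(domain_output, threshold=0.5)
loss = torch.dot(w, per_sample_loss) / torch.sum(w)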
2 changes: 1 addition & 1 deletion mnist_mnistm.py
@@ -40,7 +40,7 @@ class Config(object):
     gpu_id = '0'
 
     ## for digit
-    num_epochs = 100
+    num_epochs = 1
     log_step = 20
     save_step = 50
     eval_step = 1
8 changes: 6 additions & 2 deletions mnist_mnistm_weight.py
@@ -10,8 +10,8 @@
 from torch.utils.tensorboard import SummaryWriter
 import shutil
 
-for data_mode in range(6):
-    for run_mode in range(4):
+for data_mode in [5]:
+    for run_mode in [1,2,3]:
         class Config(object):
             # params for path
             model_name = "mnist-mnistm-weight"
@@ -21,6 +21,10 @@ class Config(object):
             os.makedirs(model_root, exist_ok=True)
             config = os.path.join(model_root, 'config.txt')
             finetune_flag = False
+            data_mode = data_mode
+            run_mode = run_mode
+            soft = True
+            quantile = False
             threshold = (0.5, 0.5)
 
             # params for datasets and data loader
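A side note on the Config classes in these driver scripts: instantiating one writes every public attribute to config.txt via introspection, which is what produces the config dumps added later in this commit. A self-contained sketch of the pattern, with hypothetical placeholder values:

import os

class Config(object):
    model_root = "runs/example"                       # hypothetical path
    config = os.path.join(model_root, "config.txt")
    soft = True
    quantile = False
    threshold = (0.5, 0.5)

    def __init__(self):
        # dump every public attribute as "name: value", one per line
        os.makedirs(self.model_root, exist_ok=True)
        public_props = (name for name in dir(self) if not name.startswith('_'))
        with open(self.config, 'w') as f:
            for name in public_props:
                f.write(name + ': ' + str(getattr(self, name)) + '\n')

params = Config()   # writes runs/example/config.txt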
211 changes: 105 additions & 106 deletions mnist_usps_weight.py
@@ -11,109 +11,108 @@
 import numpy as np
 import shutil
 
-for data_mode in [0,5,3]:
-    for run_mode in [0,1,2,3]:
-        for T in [0.55,0.6,0.8,0.9]:
-            class Config(object):
-                # params for path
-                model_name = "mnist-usps-weight"
-                dataset_root = get_dataset_root()
-                model_root = get_model_root(model_name, data_mode, run_mode)
-                model_root = os.path.join(model_root, datetime.datetime.now().strftime('%m%d_%H%M%S'))
-                os.makedirs(model_root, exist_ok=True)
-                config = os.path.join(model_root, 'config.txt')
-                finetune_flag = False
-                data_mode = data_mode
-                run_mode = run_mode
-                threshold = (T,T)
-                soft = False
-                quantile = False
-
-                # params for datasets and data loader
-                batch_size = 64
-
-                # params for source dataset
-                src_dataset = "mnist"
-                src_model_trained = True
-                src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')
-                class_num_src = 31
-
-                # params for target dataset
-                tgt_dataset = "usps"
-                tgt_model_trained = True
-                dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
-
-                # params for pretrain
-                num_epochs_src = 100
-                log_step_src = 10
-                save_step_src = 50
-                eval_step_src = 20
-
-                # params for training dann
-                gpu_id = '0'
-
-                ## for digit
-                num_epochs = 100
-                log_step = 20
-                save_step = 50
-                eval_step = 1
-
-                ## for office
-                # num_epochs = 1000
-                # log_step = 10 # iters
-                # save_step = 500
-                # eval_step = 5 # epochs
-                lr_adjust_flag = 'simple'
-                src_only_flag = False
-
-                manual_seed = 8888
-                alpha = 0
-
-                # params for optimizing models
-                lr = 5e-4
-                momentum = 0
-                weight_decay = 0
-
-                def __init__(self):
-                    public_props = (name for name in dir(self) if not name.startswith('_'))
-                    with open(self.config, 'w') as f:
-                        for name in public_props:
-                            f.write(name + ': ' + str(getattr(self, name)) + '\n')
-
-            params = Config()
-
-            logger = SummaryWriter(params.model_root)
-
-            # init random seed
-            init_random_seed(params.manual_seed)
-
-            # init device
-            device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu")
-
-
-            print(data_mode, run_mode)
-            source_weight, target_weight = get_data(params.data_mode)
-            print(source_weight, target_weight)
-            if params.data_mode == 3:
-                src_data_loader, num_src_train = get_data_loader_weight(
-                    params.src_dataset, params.dataset_root, params.batch_size, train=True, weights = source_weight)
-                src_data_loader_eval, _ = get_data_loader_weight(
-                    params.src_dataset, params.dataset_root, params.batch_size, train=False, weights = source_weight)
-            else:
-                src_data_loader, num_src_train = get_data_loader_weight(params.src_dataset, params.dataset_root, params.batch_size, train=True)
-                src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False)
-            tgt_data_loader, num_tgt_train = get_data_loader_weight(
-                params.tgt_dataset, params.dataset_root, params.batch_size, train=True, weights = target_weight)
-            tgt_data_loader_eval, _ = get_data_loader_weight(
-                params.tgt_dataset, params.dataset_root, params.batch_size, train=False, weights = target_weight)
-            # Cannot use the same sampler for both training and testing dataset
-
-
-            # load dann model
-            dann = init_model(net=MNISTmodel(), restore=None)
-
-            # train dann model
-            print("Training dann model")
-            if not (dann.restored and params.dann_restore):
-                dann = train_dann(dann, params, src_data_loader, tgt_data_loader,
-                                  tgt_data_loader_eval, num_src_train, num_tgt_train, device, logger)
+for data_mode in [1,2]:
+    for run_mode in [1,2,3]:
+        class Config(object):
+            # params for path
+            model_name = "mnist-usps-weight"
+            dataset_root = get_dataset_root()
+            model_root = get_model_root(model_name, data_mode, run_mode)
+            model_root = os.path.join(model_root, datetime.datetime.now().strftime('%m%d_%H%M%S'))
+            os.makedirs(model_root, exist_ok=True)
+            config = os.path.join(model_root, 'config.txt')
+            finetune_flag = False
+            data_mode = data_mode
+            run_mode = run_mode
+            threshold = (0.5,0.5)
+            soft = True
+            quantile = False
+
+            # params for datasets and data loader
+            batch_size = 64
+
+            # params for source dataset
+            src_dataset = "mnist"
+            src_model_trained = True
+            src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')
+            class_num_src = 31
+
+            # params for target dataset
+            tgt_dataset = "usps"
+            tgt_model_trained = True
+            dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
+
+            # params for pretrain
+            num_epochs_src = 100
+            log_step_src = 10
+            save_step_src = 50
+            eval_step_src = 20
+
+            # params for training dann
+            gpu_id = '0'
+
+            ## for digit
+            num_epochs = 100
+            log_step = 20
+            save_step = 50
+            eval_step = 1
+
+            ## for office
+            # num_epochs = 1000
+            # log_step = 10 # iters
+            # save_step = 500
+            # eval_step = 5 # epochs
+            lr_adjust_flag = 'simple'
+            src_only_flag = False
+
+            manual_seed = 8888
+            alpha = 0
+
+            # params for optimizing models
+            lr = 5e-4
+            momentum = 0
+            weight_decay = 0
+
+            def __init__(self):
+                public_props = (name for name in dir(self) if not name.startswith('_'))
+                with open(self.config, 'w') as f:
+                    for name in public_props:
+                        f.write(name + ': ' + str(getattr(self, name)) + '\n')
+
+        params = Config()
+
+        logger = SummaryWriter(params.model_root)
+
+        # init random seed
+        init_random_seed(params.manual_seed)
+
+        # init device
+        device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu")
+
+
+        print(data_mode, run_mode)
+        source_weight, target_weight = get_data(params.data_mode)
+        print(source_weight, target_weight)
+        if params.data_mode == 3:
+            src_data_loader, num_src_train = get_data_loader_weight(
+                params.src_dataset, params.dataset_root, params.batch_size, train=True, weights = source_weight)
+            src_data_loader_eval, _ = get_data_loader_weight(
+                params.src_dataset, params.dataset_root, params.batch_size, train=False, weights = source_weight)
+        else:
+            src_data_loader, num_src_train = get_data_loader_weight(params.src_dataset, params.dataset_root, params.batch_size, train=True)
+            src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False)
+        tgt_data_loader, num_tgt_train = get_data_loader_weight(
+            params.tgt_dataset, params.dataset_root, params.batch_size, train=True, weights = target_weight)
+        tgt_data_loader_eval, _ = get_data_loader_weight(
+            params.tgt_dataset, params.dataset_root, params.batch_size, train=False, weights = target_weight)
+        # Cannot use the same sampler for both training and testing dataset
+
+
+        # load dann model
+        dann = init_model(net=MNISTmodel(), restore=None)
+
+        # train dann model
+        print("Training dann model")
+        if not (dann.restored and params.dann_restore):
+            dann = train_dann(dann, params, src_data_loader, tgt_data_loader,
+                              tgt_data_loader_eval, num_src_train, num_tgt_train, device, logger)
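Both versions of this script rely on the same sweep pattern: Config is redefined inside the loops, so assignments like data_mode = data_mode freeze the current module-level loop values as class attributes, and each run gets its own timestamped model_root. A stripped-down sketch (the runs/... layout is an assumption mirroring the paths elsewhere in this commit):

import datetime
import os

for data_mode in [1, 2]:
    for run_mode in [1, 2, 3]:
        class Config(object):
            # The right-hand sides resolve to the module-level loop
            # variables, so each redefinition captures the current pair.
            data_mode = data_mode
            run_mode = run_mode
            model_root = os.path.join(
                "runs", "mnist-usps-weight",
                "data%d" % data_mode, "run%d" % run_mode,
                datetime.datetime.now().strftime('%m%d_%H%M%S'))

        params = Config()
        print(params.data_mode, params.run_mode, params.model_root)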
34 changes: 34 additions & 0 deletions runs/mnist-mnistm-weight/data5/run0/0717_102639/config.txt
@@ -0,0 +1,34 @@
+alpha: 0
+batch_size: 64
+class_num_src: 31
+config: runs/mnist-mnistm-weight/data5/run0/0717_102639/config.txt
+dann_restore: runs/mnist-mnistm-weight/data5/run0/0717_102639/mnist-mnistm-dann-final.pt
+data_mode: 5
+dataset_root: /nobackup/yguo/datasets
+eval_step: 1
+eval_step_src: 20
+finetune_flag: False
+gpu_id: 0
+log_step: 20
+log_step_src: 10
+lr: 0.0005
+lr_adjust_flag: simple
+manual_seed: 8888
+model_name: mnist-mnistm-weight
+model_root: runs/mnist-mnistm-weight/data5/run0/0717_102639
+momentum: 0
+num_epochs: 100
+num_epochs_src: 100
+quantile: False
+run_mode: 0
+save_step: 50
+save_step_src: 50
+soft: True
+src_classifier_restore: runs/mnist-mnistm-weight/data5/run0/0717_102639/mnist-source-classifier-final.pt
+src_dataset: mnist
+src_model_trained: True
+src_only_flag: False
+tgt_dataset: mnistm
+tgt_model_trained: True
+threshold: (0.5, 0.5)
+weight_decay: 0
Binary file not shown.
34 changes: 34 additions & 0 deletions runs/mnist-mnistm-weight/data5/run0/0717_135718/config.txt
@@ -0,0 +1,34 @@
+alpha: 0
+batch_size: 64
+class_num_src: 31
+config: runs/mnist-mnistm-weight/data5/run0/0717_135718/config.txt
+dann_restore: runs/mnist-mnistm-weight/data5/run0/0717_135718/mnist-mnistm-dann-final.pt
+data_mode: 5
+dataset_root: /nobackup/yguo/datasets
+eval_step: 1
+eval_step_src: 20
+finetune_flag: False
+gpu_id: 0
+log_step: 20
+log_step_src: 10
+lr: 0.0005
+lr_adjust_flag: simple
+manual_seed: 8888
+model_name: mnist-mnistm-weight
+model_root: runs/mnist-mnistm-weight/data5/run0/0717_135718
+momentum: 0
+num_epochs: 100
+num_epochs_src: 100
+quantile: False
+run_mode: 0
+save_step: 50
+save_step_src: 50
+soft: True
+src_classifier_restore: runs/mnist-mnistm-weight/data5/run0/0717_135718/mnist-source-classifier-final.pt
+src_dataset: mnist
+src_model_trained: True
+src_only_flag: False
+tgt_dataset: mnistm
+tgt_model_trained: True
+threshold: (0.5, 0.5)
+weight_decay: 0
Binary file not shown.