finish the implementation
tensorsketch committed Jul 15, 2020
1 parent 12a58f4 commit 579288f
Showing 3 changed files with 159 additions and 95 deletions.
18 changes: 10 additions & 8 deletions core/train_weight.py
@@ -26,7 +26,7 @@ def normalized_weight(ten, a = 10, reverse = False):
         w = weight(1-prob)
     return w
 
-def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, num_src, num_tgt, device, logger, mode = "DANN"):
+def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_eval, num_src, num_tgt, device, logger, mode = 0):
     """Train dann."""
     ####################
     # 1. setup network #
@@ -97,20 +97,22 @@ def train_dann(model, params, src_data_loader, tgt_data_loader, tgt_data_loader_
             # train on source domain
             src_class_output, src_domain_output = model(input_data=images_src, alpha=alpha)
             src_loss_class = criterion0(src_class_output, class_src)
-            if mode == "DANN":
+            if mode in [0,2]:
                 src_loss_domain = criterion0(src_domain_output, label_src)
-            elif mode == "WINSORIZE":
+            else:
                 src_loss_domain = criterion(src_domain_output, label_src)
                 weight_src[idx_src] = normalized_weight(src_domain_output.data).detach()
                 src_loss_domain = torch.dot(weight_src[idx_src], src_loss_domain
                     )/ torch.sum(weight_src[idx_src])
             #train on target domain
             _, tgt_domain_output = model(input_data=images_tgt, alpha=alpha)
-            tgt_loss_domain = criterion0(tgt_domain_output, label_tgt)
-            # weight_tgt[idx_tgt] = normalized_weight(
-            #     tgt_domain_output.data, reverse = True).detach()
-            # tgt_loss_domain = torch.dot(weight_tgt[idx_tgt], tgt_loss_domain
-            #     ) / torch.sum(weight_tgt[idx_tgt])
+            if mode in [0,1]:
+                tgt_loss_domain = criterion0(tgt_domain_output, label_tgt)
+            else:
+                weight_tgt[idx_tgt] = normalized_weight(
+                    tgt_domain_output.data, reverse = True).detach()
+                tgt_loss_domain = torch.dot(weight_tgt[idx_tgt], tgt_loss_domain
+                    ) / torch.sum(weight_tgt[idx_tgt])
 
 
             loss = src_loss_class + src_loss_domain + tgt_loss_domain
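
Note on the hunk above: `mode` now selects which domain loss is re-weighted, matching the run-mode table added in utils/utils.py (0 = plain DANN, 1 = source weighting, 2 = target weighting, 3 = winsorization of both). A minimal self-contained sketch of the branching, assuming (inferred from usage, not shown in this hunk) that `criterion0` is a mean-reduced cross-entropy and `criterion` its per-sample `reduction='none'` variant:

import torch
import torch.nn as nn

# Assumed stand-ins: criterion0 reduces to a scalar, criterion keeps
# per-sample losses so they can be re-weighted before averaging.
criterion0 = nn.CrossEntropyLoss()
criterion = nn.CrossEntropyLoss(reduction='none')

def domain_losses(src_out, src_lbl, w_src, tgt_out, tgt_lbl, w_tgt, mode=0):
    # mode 0: plain DANN; 1: re-weight source; 2: re-weight target; 3: both.
    if mode in [0, 2]:
        src_loss = criterion0(src_out, src_lbl)
    else:
        per_sample = criterion(src_out, src_lbl)
        src_loss = torch.dot(w_src, per_sample) / torch.sum(w_src)
    if mode in [0, 1]:
        tgt_loss = criterion0(tgt_out, tgt_lbl)
    else:
        per_sample = criterion(tgt_out, tgt_lbl)
        tgt_loss = torch.dot(w_tgt, per_sample) / torch.sum(w_tgt)
    return src_loss, tgt_loss

Either way each branch yields a scalar, so the final sum src_loss_class + src_loss_domain + tgt_loss_domain is unchanged in shape.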
178 changes: 91 additions & 87 deletions mnist_mnistm_weight.py
@@ -5,92 +5,96 @@
 sys.path.append('../')
 from models.model import MNISTmodel, MNISTmodel_plain
 from core.train_weight import train_dann
-from utils.utils import get_data_loader, get_data_loader_weight, init_model, init_random_seed, get_dataset_root
+from utils.utils import get_data_loader, get_data_loader_weight, init_model, init_random_seed, get_dataset_root, get_model_root, get_data
 from torch.utils.tensorboard import SummaryWriter
 
 
-class Config(object):
-    # params for path
-    model_name = "mnist-mnistm-weight"
-    dataset_root = get_dataset_root()
-    model_root = os.path.expanduser(os.path.join('runs', model_name))
-    finetune_flag = False
-
-    # params for datasets and data loader
-    batch_size = 64
-
-    # params for source dataset
-    src_dataset = "mnist"
-    src_model_trained = True
-    src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')
-    class_num_src = 31
-
-    # params for target dataset
-    tgt_dataset = "mnistm"
-    tgt_model_trained = True
-    dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
-
-    # params for pretrain
-    num_epochs_src = 100
-    log_step_src = 10
-    save_step_src = 50
-    eval_step_src = 20
-
-    # params for training dann
-    gpu_id = '0'
-
-    ## for digit
-    num_epochs = 100
-    log_step = 20
-    save_step = 50
-    eval_step = 1
-
-    ## for office
-    # num_epochs = 1000
-    # log_step = 10 # iters
-    # save_step = 500
-    # eval_step = 5 # epochs
-    lr_adjust_flag = 'simple'
-    src_only_flag = False
-
-    manual_seed = 8888
-    alpha = 0
-
-    # params for optimizing models
-    lr = 5e-4
-    momentum = 0
-    weight_decay = 0
-
-params = Config()
-
-logger = SummaryWriter(params.model_root)
-
-# init random seed
-init_random_seed(params.manual_seed)
-
-# init device
-device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu")
-
-
-WEIGHTS = torch.ones(10)
-# load dataset
-src_data_loader, num_src_train = get_data_loader_weight(params.src_dataset, params.dataset_root, params.batch_size, train=True)
-src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False)
-tgt_data_loader, num_tgt_train = get_data_loader_weight(
-    params.tgt_dataset, params.dataset_root, params.batch_size, train=True, sampler=torch.utils.data.sampler.WeightedRandomSampler(
-        WEIGHTS, 1))
-tgt_data_loader_eval, _ = get_data_loader_weight(
-    params.tgt_dataset, params.dataset_root, params.batch_size, train=False, sampler=torch.utils.data.sampler.WeightedRandomSampler(
-        WEIGHTS, 1))
-# Cannot use the same sampler for both training and testing dataset
-
-
-# load dann model
-dann = init_model(net=MNISTmodel(), restore=None)
-
-# train dann model
-print("Training dann model")
-if not (dann.restored and params.dann_restore):
-    dann = train_dann(dann, params, src_data_loader, tgt_data_loader,
-                      tgt_data_loader_eval, num_src_train, num_tgt_train, device, logger)
+for data_mode, run_mode in zip(range(4), range(4)):
+    class Config(object):
+        # params for path
+        model_name = "mnist-mnistm-weight"
+        dataset_root = get_dataset_root()
+        model_root = get_model_root(model_name, data_mode, run_mode)
+        finetune_flag = False
+
+        # params for datasets and data loader
+        batch_size = 64
+
+        # params for source dataset
+        src_dataset = "mnist"
+        src_model_trained = True
+        src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')
+        class_num_src = 31
+        data_mode = data_mode
+        run_mode = run_mode
+
+        # params for target dataset
+        tgt_dataset = "mnistm"
+        tgt_model_trained = True
+        dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
+
+        # params for pretrain
+        num_epochs_src = 100
+        log_step_src = 10
+        save_step_src = 50
+        eval_step_src = 20
+
+        # params for training dann
+        gpu_id = '0'
+
+        ## for digit
+        num_epochs = 1
+        log_step = 20
+        save_step = 50
+        eval_step = 1
+
+        ## for office
+        # num_epochs = 1000
+        # log_step = 10 # iters
+        # save_step = 500
+        # eval_step = 5 # epochs
+        lr_adjust_flag = 'simple'
+        src_only_flag = False
+
+        manual_seed = 8888
+        alpha = 0
+
+        # params for optimizing models
+        lr = 5e-4
+        momentum = 0
+        weight_decay = 0
+
+    params = Config()
+
+    logger = SummaryWriter(params.model_root)
+
+    # init random seed
+    init_random_seed(params.manual_seed)
+
+    # init device
+    device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu")
+
+    source_weight, target_weight = get_data(params.data_mode)
+
+    # load dataset
+    if params.data_mode == 3:
+        src_data_loader, num_src_train = get_data_loader_weight(
+            params.src_dataset, params.src_image_root, params.batch_size, train=True, weights = source_weight)
+        src_data_loader_eval, _ = get_data_loader_weight(
+            params.src_dataset, params.src_image_root, params.batch_size, train=False, weights = source_weight)
+    else:
+        src_data_loader, num_src_train = get_data_loader_weight(params.src_dataset, params.dataset_root, params.batch_size, train=True)
+        src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False)
+    tgt_data_loader, num_tgt_train = get_data_loader_weight(
+        params.tgt_dataset, params.tgt_image_root, params.batch_size, train=True, weights = target_weight)
+    tgt_data_loader_eval, _ = get_data_loader_weight(
+        params.tgt_dataset,
+        params.tgt_image_root, params.batch_size, train=False, weights = target_weight)
+
+    # load dann model
+    dann = init_model(net=MNISTmodel(), restore=None)
+
+    # train dann model
+    print("Training dann model")
+    if not (dann.restored and params.dann_restore):
+        dann = train_dann(dann, params, src_data_loader, tgt_data_loader,
+                          tgt_data_loader_eval, num_src_train, num_tgt_train, device, logger)
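
A few caveats in the rewritten driver. `zip(range(4), range(4))` yields only the diagonal pairs (0,0), (1,1), (2,2), (3,3), so each data mode is only ever paired with the run mode of the same index; sweeping all 16 combinations would need `itertools.product`, as in the sketch below (my assumption about the intent, not what the commit does). Also, `run_mode` is stored on `Config` but never forwarded to `train_dann`, whose `mode` argument therefore keeps its default of 0; presumably the call was meant to pass `mode=params.run_mode`. Finally, `params.src_image_root` and `params.tgt_image_root` are referenced in the loader calls but never defined on this `Config` (only `dataset_root` is), and the target loaders reference `tgt_image_root` unconditionally, which would raise an AttributeError.

from itertools import product

# Hypothetical full sweep; the committed loop visits only the four diagonal pairs.
for data_mode, run_mode in product(range(4), range(4)):
    print(data_mode, run_mode)  # 16 (data_mode, run_mode) combinations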
58 changes: 58 additions & 0 deletions utils/utils.py
@@ -20,6 +20,64 @@
 def get_dataset_root():
     return '/nobackup/yguo/datasets'
 
+# data mode:
+# 0. All one (uniform samplings)
+# 1. First half 0, second half 1
+# 2. First half 1, second half 0
+# 3. Overlapping support: 0-6 --> 3-9
+# 4. Mild random weight
+# 5. Strong random weight
+
+
+# run mode:
+# 0. DANN
+# 1. source weighting + method 1
+# 2. target weighting + method 2
+# 3. winsorization + method 3
+
+
+def get_model_root(model_name, data_mode, run_mode):
+    data_mode = 'data{}'.format(data_mode)
+    run_mode = 'run{}'.format(run_mode)
+    model_root = os.path.expanduser(os.path.join('runs', model_name, data_mode, run_mode))
+    return model_root
+
+def get_data(mode):
+    # Return a tuple of lists, where the first list corresponds to the source
+    # weight, and the second part corresponds to the target weight
+    if mode == 0:
+        source_weight = torch.tensor([])
+        target_weight = torch.ones(10)
+        return (source_weight, target_weight)
+    elif mode == 1:
+        source_weight = torch.tensor([])
+        target_weight = torch.tensor([0,0,0,0,0,1,1,1,1,1])
+        return (source_weight, target_weight)
+    elif mode == 2:
+        source_weight = torch.tensor([])
+        target_weight = torch.tensor([1,1,1,1,1,0,0,0,0,0])
+        return (source_weight, target_weight)
+    elif mode == 3:
+        source_weight = torch.tensor([1,1,1,1,1,1,0,0,0,0])
+        target_weight = torch.tensor([0,0,0,0,1,1,1,1,1,1])
+    elif mode == 4:
+        value = 0.25
+        source_weight = torch.tensor([])
+        target_weight = torch.tensor(np.concatenate(
+            [[value], np.random.uniform(value, 1-value, 8), [1-value]]))
+        return (source_weight, target_weight)
+    elif mode == 5:
+        value = 0.0625
+        source_weight = torch.tensor([])
+        target_weight = torch.tensor(np.concatenate(
+            [[value], np.random.uniform(value, 1-value, 8), [1-value]]))
+        return (source_weight, target_weight)
+    else:
+        source_weight = torch.tensor([])
+        target_weight = torch.tensor([0,0,0,0,0,0,0,0,0,1])
+        return (source_weight, target_weight)
+
+
 def make_cuda(tensor):
     """Use CUDA if it's available."""
     if torch.cuda.is_available():
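
Two notes on the additions above. In `get_data`, the `mode == 3` branch is the only one without a `return`, so the function falls through and returns None; any caller that unpacks the result (as mnist_mnistm_weight.py does) will then fail. As written, the nonzero weights also cover source digits 0-5 and target digits 4-9, not quite the "0-6 --> 3-9" the comment promises. A corrected branch would presumably just mirror the others; a sketch (the helper name is mine, for illustration):

import torch

def get_data_mode3_fixed():
    # Overlapping support: source keeps digits 0-5, target keeps digits 4-9.
    source_weight = torch.tensor([1, 1, 1, 1, 1, 1, 0, 0, 0, 0])
    target_weight = torch.tensor([0, 0, 0, 0, 1, 1, 1, 1, 1, 1])
    return (source_weight, target_weight)  # the committed branch omits this return

For reference, get_model_root('mnist-mnistm-weight', 1, 2) resolves to runs/mnist-mnistm-weight/data1/run2, giving each (data mode, run mode) pair its own checkpoint and TensorBoard directory; and despite the docstring comment, get_data returns torch tensors rather than lists, with modes 4 and 5 additionally relying on numpy (np) being imported in utils.py.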
Expand Down
