Commit

fix code

tensorsketch committed Jul 11, 2020
1 parent 2bcdae7 commit 1b76b53
Showing 10 changed files with 151 additions and 60 deletions.
28 changes: 19 additions & 9 deletions datasets/mnistm_weight.py
@@ -37,8 +37,11 @@ def __getitem__(self, item):

     def __len__(self):
         return self.n_data

+    def get_labels(self):
+        return torch.tensor([int(label) for label in self.img_labels])
+

-def get_mnistm_weight(dataset_root, batch_size, train, weights = None):
+def get_mnistm_weight(dataset_root, batch_size, train, weights):
     """Get MNISTM datasets loader."""
     # image pre-processing, each image from MNIST-M has shape 32x32
     pre_process = transforms.Compose([transforms.Resize(28),
@@ -64,13 +67,19 @@ def get_mnistm_weight(dataset_root, batch_size, train, weights):
                                     data_list=train_list,
                                     transform=pre_process)
     num_sample = len(mnistm_dataset)
-    if sampler is not None:
-        sampler.num_samples = num_sample
-    mnistm_dataloader = torch.utils.data.DataLoader(
-        dataset=mnistm_dataset,
-        batch_size=batch_size,
-        shuffle=True,
-        sampler=sampler,
-        num_workers=8)
+    if len(weights) == 10:
+        sample_weight = torch.tensor([weights[label] for label in mnistm_dataset.get_labels()])
+        mnistm_dataloader = torch.utils.data.DataLoader(
+            dataset=mnistm_dataset,
+            batch_size=batch_size,
+            sampler=torch.utils.data.sampler.WeightedRandomSampler(
+                sample_weight, len(sample_weight)),
+            num_workers=8)
+    else:
+        mnistm_dataloader = torch.utils.data.DataLoader(
+            dataset=mnistm_dataset,
+            batch_size=batch_size,
+            shuffle=True,
+            num_workers=8)

     return mnistm_dataloader, num_sample
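For reference, a minimal self-contained sketch of the sampling pattern this hunk introduces: a per-class weight vector is expanded into per-sample weights, which drive a WeightedRandomSampler. The dataset and labels below are stand-ins, not the repo's MNISTM class; note that DataLoader treats shuffle=True and sampler as mutually exclusive, which is why the weighted branch above must not set shuffle.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

# Stand-in data: 1000 fake images with labels from 10 classes.
images = torch.randn(1000, 3, 28, 28)
labels = torch.randint(0, 10, (1000,))

class_weights = torch.ones(10)
class_weights[3] = 5.0                      # e.g. oversample class 3

# Indexing expands the 10 class weights into one weight per sample.
sample_weight = class_weights[labels]
sampler = WeightedRandomSampler(sample_weight, num_samples=len(sample_weight))

# shuffle must stay unset here: DataLoader raises a ValueError when both
# shuffle and sampler are given.
loader = DataLoader(TensorDataset(images, labels), batch_size=64, sampler=sampler)
```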
7 changes: 4 additions & 3 deletions datasets/usps.py
@@ -11,9 +11,10 @@ def get_usps(dataset_root, batch_size, train):
     pre_process = transforms.Compose([transforms.Resize(28),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
-                                          mean=[0.5],
-                                          std=[0.5]
-                                      )])
+                                          mean=[0.2473],  # mean of the USPS training data
+                                          std=[0.2665]    # std of the USPS training data
+                                      )
+                                      ])

     # datasets and data loader
     usps_dataset = datasets.USPS(root=os.path.join(dataset_root),
4 changes: 2 additions & 2 deletions datasets/usps_weight.py
@@ -23,8 +23,8 @@ def get_usps_weight(dataset_root, batch_size, train, weights):
     pre_process = transforms.Compose([transforms.Resize(28),
                                       transforms.ToTensor(),
                                       transforms.Normalize(
-                                          mean=[0.5],
-                                          std=[0.5]
+                                          mean=[0.2473],  # mean of the USPS training data
+                                          std=[0.2665]    # std of the USPS training data
                                       )])

     # datasets and data loader
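As a quick sanity check on the constants above (assuming they really are the training-set statistics), normalizing with dataset-specific mean/std should leave the data roughly zero-mean and unit-variance; a sketch with stand-in data:

```python
import torch
from torchvision import transforms

x = torch.rand(256, 1, 28, 28) * 0.5 + 0.1               # stand-in "dataset"
norm = transforms.Normalize(mean=[x.mean().item()], std=[x.std().item()])
y = torch.stack([norm(img) for img in x])                # normalize image by image
print(round(y.mean().item(), 3), round(y.std().item(), 3))  # ~0.0, ~1.0
```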
36 changes: 11 additions & 25 deletions mean.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 1,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -22,32 +22,18 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 4,
    "metadata": {
     "tags": []
    },
-   "outputs": [
-    {
-     "output_type": "error",
-     "ename": "FileNotFoundError",
-     "evalue": "[Errno 2] No such file or directory: '/nobackup/yguo/datasets/gtsrb_train.p'",
-     "traceback": [ANSI-colored frames (get_gtsrb -> GTSRB.__init__ -> open(filepath)); elided here]
-    }
-   ],
+   "outputs": [],
    "source": [
-    "loader = get_gtsrb('/nobackup/yguo/datasets', 10000, True)"
+    "loader = get_usps('/nobackup/yguo/datasets', 10, True)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 5,
    "metadata": {
     "tags": []
    },
@@ -69,16 +55,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 11,
+   "execution_count": 6,
    "metadata": {},
    "outputs": [
     {
      "output_type": "execute_result",
      "data": {
-      "text/plain": "tensor([0.4376, 0.4438, 0.4729])"
+      "text/plain": "tensor([0.2473])"
      },
      "metadata": {},
-     "execution_count": 11
+     "execution_count": 6
     }
    ],
    "source": [
@@ -87,16 +73,16 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 7,
    "metadata": {},
    "outputs": [
     {
      "output_type": "execute_result",
      "data": {
-      "text/plain": "tensor([0.1201, 0.1231, 0.1052])"
+      "text/plain": "tensor([0.2665])"
      },
      "metadata": {},
-     "execution_count": 12
+     "execution_count": 7
     }
    ],
    "source": [
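The notebook cells above (their source is truncated in this view) evidently compute the USPS training-set statistics now baked into the Normalize calls. A sketch of one way to reproduce such numbers with torchvision's USPS dataset; the root path is a placeholder, and the exact printed values may differ slightly from the notebook's depending on how the std is estimated:

```python
import torch
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

# Load the whole USPS training split as a single batch of tensors.
usps = datasets.USPS(root='/tmp/usps', train=True, download=True,
                     transform=transforms.ToTensor())
images, _ = next(iter(DataLoader(usps, batch_size=len(usps))))

print(images.mean(dim=(0, 2, 3)))   # per-channel mean, e.g. tensor([0.2473])
print(images.std(dim=(0, 2, 3)))    # per-channel std,  e.g. tensor([0.2665])
```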
11 changes: 6 additions & 5 deletions mnist_mnistm_weight.py
@@ -5,15 +5,16 @@
 sys.path.append('../')
 from models.model import MNISTmodel, MNISTmodel_plain
 from core.train_weight import train_dann
-from utils.utils import get_data_loader, get_data_loader_weight, init_model, init_random_seed
+from utils.utils import get_data_loader, get_data_loader_weight, init_model, init_random_seed, get_dataset_root
 from torch.utils.tensorboard import SummaryWriter


 class Config(object):
     # params for path
     model_name = "mnist-mnistm-weight"
-    dataset_root = os.path.expanduser('/nobackup/yguo/dataset')
-    model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN', model_name))
+    dataset_root = get_dataset_root()
+    model_root = os.path.expanduser(os.path.join('runs', model_name))
     finetune_flag = False

     # params for datasets and data loader
@@ -63,7 +64,7 @@ class Config(object):

 params = Config()

-logger = SummaryWriter(params.model_root, flush_secs = 10)
+logger = SummaryWriter(params.model_root)

 # init random seed
 init_random_seed(params.manual_seed)
@@ -86,7 +87,7 @@ class Config(object):


 # load dann model
-dann = init_model(net=MNISTmodel_plain(), restore=None)
+dann = init_model(net=MNISTmodel(), restore=None)

 # train dann model
 print("Training dann model")
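get_dataset_root() is imported here but its body is not part of this diff; a hypothetical sketch of what such a helper often looks like (environment variable with a home-directory fallback), purely for illustration:

```python
import os

def get_dataset_root() -> str:
    # Hypothetical implementation: honor $DATASET_ROOT if set,
    # otherwise fall back to ~/datasets.
    return os.path.expanduser(os.environ.get('DATASET_ROOT', '~/datasets'))
```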
96 changes: 96 additions & 0 deletions mnist_mnistm_winsorize.py
@@ -0,0 +1,96 @@
import os
import sys

import torch
sys.path.append('../')
from models.model import MNISTmodel, MNISTmodel_plain
from core.train_weight import train_dann
from utils.utils import get_data_loader, get_data_loader_weight, init_model, init_random_seed, get_dataset_root
from torch.utils.tensorboard import SummaryWriter


class Config(object):
    # params for path
    model_name = "mnist-mnistm-weight"
    dataset_root = get_dataset_root()
    model_root = os.path.expanduser(os.path.join('runs', model_name))
    finetune_flag = False

    # params for datasets and data loader
    batch_size = 64

    # params for source dataset
    src_dataset = "mnist"
    src_model_trained = True
    src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')
    class_num_src = 31

    # params for target dataset
    tgt_dataset = "mnistm"
    tgt_model_trained = True
    dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')

    # params for pretrain
    num_epochs_src = 100
    log_step_src = 10
    save_step_src = 50
    eval_step_src = 20

    # params for training dann
    gpu_id = '0'

    ## for digit
    num_epochs = 100
    log_step = 20
    save_step = 50
    eval_step = 1

    ## for office
    # num_epochs = 1000
    # log_step = 10  # iters
    # save_step = 500
    # eval_step = 5  # epochs
    lr_adjust_flag = 'simple'
    src_only_flag = False

    manual_seed = 8888
    alpha = 0

    # params for optimizing models
    lr = 5e-4
    momentum = 0
    weight_decay = 0

params = Config()

logger = SummaryWriter(params.model_root)

# init random seed
init_random_seed(params.manual_seed)

# init device
device = torch.device("cuda:" + params.gpu_id if torch.cuda.is_available() else "cpu")


WEIGHTS = torch.ones(10)
# load dataset
src_data_loader, num_src_train = get_data_loader_weight(params.src_dataset, params.dataset_root, params.batch_size, train=True)
src_data_loader_eval = get_data_loader(params.src_dataset, params.dataset_root, params.batch_size, train=False)
tgt_data_loader, num_tgt_train = get_data_loader_weight(
    params.tgt_dataset, params.dataset_root, params.batch_size, train=True, weights=WEIGHTS)
tgt_data_loader_eval, _ = get_data_loader_weight(
    params.tgt_dataset, params.dataset_root, params.batch_size, train=False, weights=WEIGHTS)
# A sampler cannot be shared between the training and testing datasets,
# so get_data_loader_weight builds a fresh one inside each loader.


# load dann model
dann = init_model(net=MNISTmodel_plain(), restore=None)

# train dann model
print("Training dann model")
if not (dann.restored and params.dann_restore):
    dann = train_dann(dann, params, src_data_loader, tgt_data_loader,
                      tgt_data_loader_eval, num_src_train, num_tgt_train, device, logger)
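The file name suggests the target weights will eventually be winsorized (clipped at quantiles) rather than left uniform as the torch.ones(10) placeholder above. A hypothetical helper showing that operation; nothing with this name exists in the repo:

```python
import torch

def winsorize(weights: torch.Tensor, lower: float = 0.05, upper: float = 0.95) -> torch.Tensor:
    """Clip a weight vector to its [lower, upper] quantile range."""
    lo = torch.quantile(weights, lower).item()
    hi = torch.quantile(weights, upper).item()
    return weights.clamp(min=lo, max=hi)

w = torch.tensor([0.1, 1.0, 1.2, 0.9, 25.0])
print(winsorize(w))   # the 25.0 outlier is pulled down to the 95th-percentile value
```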
6 changes: 3 additions & 3 deletions mnist_usps.py
@@ -5,15 +5,15 @@
 sys.path.append('../')
 from models.model import MNISTmodel, MNISTmodel_plain
 from core.train import train_dann
-from utils.utils import get_data_loader, init_model, init_random_seed
+from utils.utils import get_data_loader, init_model, init_random_seed, get_dataset_root
 from torch.utils.tensorboard import SummaryWriter


 class Config(object):
     # params for path
     model_name = "mnist-usps"
-    dataset_root = os.path.expanduser('/nobackup/yguo/dataset')
-    model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN', model_name))
+    dataset_root = get_dataset_root()
+    model_root = os.path.expanduser(os.path.join('runs', model_name))
     finetune_flag = False

     # params for datasets and data loader
11 changes: 4 additions & 7 deletions svhn_mnist.py
@@ -7,16 +7,13 @@
 sys.path.append('../')
 from models.model import SVHNmodel
 from core.train import train_dann
-from utils.utils import get_data_loader, init_model, init_random_seed
+from utils.utils import get_data_loader, init_model, init_random_seed, get_dataset_root


 class Config(object):
     # params for path
     model_name = "svhn-mnist"
-    model_base = '/nobackup/yguo/pytorch-dann'
-    model_root = os.path.expanduser(os.path.join('~', 'Models', 'pytorch-DANN', model_name))
-    note = 'paper-structure'
-    model_root = os.path.join(model_base, model_name, note + '_' + datetime.datetime.now().strftime('%m%d_%H%M%S'))
+    model_root = os.path.expanduser(os.path.join('runs', model_name))
     os.makedirs(model_root)
     config = os.path.join(model_root, 'config.txt')
     finetune_flag = False
@@ -28,13 +25,13 @@ class Config(object):

     # params for source dataset
     src_dataset = "svhn"
-    src_image_root = os.path.join('/nobackup/yguo/dataset', 'svhn')
+    src_image_root = get_dataset_root()
     src_model_trained = True
     src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')

     # params for target dataset
     tgt_dataset = "mnist"
-    tgt_image_root = os.path.join('/nobackup/yguo/dataset', 'mnist')
+    tgt_image_root = get_dataset_root()
     tgt_model_trained = True
     dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
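One side effect of dropping the timestamp from model_root: the directory name is now identical across runs, so the bare os.makedirs(model_root) above raises FileExistsError on a second invocation; the same applies to svhn_mnist_weight.py below. Passing exist_ok=True avoids that, as in this sketch (the path is just what this config would produce):

```python
import os

# Idempotent directory creation: safe to call on every run.
os.makedirs(os.path.join('runs', 'svhn-mnist'), exist_ok=True)
```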
8 changes: 3 additions & 5 deletions svhn_mnist_weight.py
@@ -7,15 +7,13 @@
 sys.path.append('../')
 from models.model import SVHNmodel
 from core.train_weight import train_dann
-from utils.utils import get_data_loader, get_data_loader_weight, init_model, init_random_seed
+from utils.utils import get_data_loader, get_data_loader_weight, init_model, init_random_seed, get_dataset_root
 import numpy as np

 class Config(object):
     # params for path
     model_name = "svhn-mnist-weight"
-    model_base = '/nobackup/yguo/pytorch-dann'
-    model_root = os.path.join(model_base, model_name, '_' + datetime.datetime.now().strftime('%m%d_%H%M%S'))
+    model_root = os.path.expanduser(os.path.join('runs', model_name))
     os.makedirs(model_root)
     config = os.path.join(model_root, 'config.txt')
     finetune_flag = False
@@ -27,13 +25,13 @@ class Config(object):

     # params for source dataset
     src_dataset = "svhn"
-    src_image_root = os.path.join('/nobackup/yguo/dataset', 'svhn')
+    src_image_root = get_dataset_root()
     src_model_trained = True
     src_classifier_restore = os.path.join(model_root, src_dataset + '-source-classifier-final.pt')

     # params for target dataset
     tgt_dataset = "mnist"
-    tgt_image_root = os.path.join('/nobackup/yguo/dataset', 'mnist')
+    tgt_image_root = get_dataset_root()
     tgt_model_trained = True
     dann_restore = os.path.join(model_root, src_dataset + '-' + tgt_dataset + '-dann-final.pt')
4 changes: 3 additions & 1 deletion utils/utils.py
@@ -85,7 +85,9 @@ def get_data_loader(name, dataset_root, batch_size, train=True):


 def get_data_loader_weight(name, dataset_root, batch_size, train=True, weights=torch.tensor([])):
-    """Get data loader by name."""
+    """Get data loader by name. If len(weights) is 0 (default), no weighted
+    sampling is performed.
+    """
     if name == "mnist":
         return get_mnist_weight(dataset_root, batch_size, train, weights=weights)
     elif name == "mnistm":
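A sketch of the convention the amended docstring describes: an empty tensor (rather than None) is the "no weighting" sentinel, and a weighted sampler is built only when class weights are actually supplied. The helper below is illustrative, not the repo's code:

```python
import torch
from torch.utils.data import WeightedRandomSampler

def make_sampler(labels: torch.Tensor, weights: torch.Tensor = torch.tensor([])):
    """Return a WeightedRandomSampler, or None when weights is empty."""
    if len(weights) == 0:       # sentinel: empty tensor => uniform sampling
        return None
    sample_weight = weights[labels]
    return WeightedRandomSampler(sample_weight, num_samples=len(sample_weight))

labels = torch.randint(0, 10, (100,))
print(make_sampler(labels))                  # None -> let DataLoader shuffle
print(make_sampler(labels, torch.ones(10)))  # weighted sampler instance
```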
