feat[test]: update path to load test data.
weimingwill committed Apr 21, 2021
1 parent b63d990 commit a52e8f4
Showing 2 changed files with 102 additions and 110 deletions.
108 changes: 54 additions & 54 deletions data_utils.py
@@ -1,65 +1,76 @@
from torch.utils.data import Dataset
from PIL import Image
from torchvision import datasets, transforms
import os
import json

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import datasets, transforms

from random_erasing import RandomErasing


class ImageDataset(Dataset):
def __init__(self, imgs, transform = None):
def __init__(self, imgs, transform=None):
self.imgs = imgs
self.transform = transform

def __len__(self):
return len(self.imgs)

def __getitem__(self, index):
data,label = self.imgs[index]
data, label = self.imgs[index]
return self.transform(Image.open(data)), label


class Data():
class Data:
def __init__(self, datasets, data_dir, batch_size, erasing_p, color_jitter, train_all):
self.datasets = datasets.split(',')
self.batch_size = batch_size
self.erasing_p = erasing_p
self.color_jitter = color_jitter
self.data_dir = data_dir
self.train_all = '_all' if train_all else ''

self.data_transforms = {}
self.train_loaders = {}
self.train_dataset_sizes = {}
self.train_class_sizes = {}
self.client_list = []
self.test_loaders = {}
self.gallery_meta = {}
self.query_meta = {}
self.kd_loader = None

def transform(self):
transform_train = [
transforms.Resize((256,128), interpolation=3),
transforms.Pad(10),
transforms.RandomCrop((256,128)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]
transforms.Resize((256, 128), interpolation=3),
transforms.Pad(10),
transforms.RandomCrop((256, 128)),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]

transform_val = [
transforms.Resize(size=(256,128),interpolation=3),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]
transforms.Resize(size=(256, 128), interpolation=3),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]

if self.erasing_p > 0:
transform_train = transform_train + [RandomErasing(probability=self.erasing_p, mean=[0.0, 0.0, 0.0])]

if self.color_jitter:
transform_train = [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)] + transform_train
transform_train = [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)] \
+ transform_train

self.data_transforms = {
'train': transforms.Compose(transform_train),
'val': transforms.Compose(transform_val),
}
}

def preprocess_kd_data(self, dataset):
loader, image_dataset = self.preprocess_one_train_dataset(dataset)
self.kd_loader = loader


def preprocess_one_train_dataset(self, dataset):
"""preprocess a training dataset, construct a data loader.
"""
@@ -68,97 +79,86 @@ def preprocess_one_train_dataset(self, dataset):
image_dataset = datasets.ImageFolder(data_path)

loader = torch.utils.data.DataLoader(
ImageDataset(image_dataset.imgs, self.data_transforms['train']),
ImageDataset(image_dataset.imgs, self.data_transforms['train']),
batch_size=self.batch_size,
shuffle=True,
num_workers=2,
shuffle=True,
num_workers=2,
pin_memory=False)

return loader, image_dataset

def preprocess_train(self):
"""preprocess training data, constructing train loaders
"""
self.train_loaders = {}
self.train_dataset_sizes = {}
self.train_class_sizes = {}
self.client_list = []

for dataset in self.datasets:
self.client_list.append(dataset)

loader, image_dataset = self.preprocess_one_train_dataset(dataset)

self.train_dataset_sizes[dataset] = len(image_dataset)
self.train_class_sizes[dataset] = len(image_dataset.classes)
self.train_loaders[dataset] = loader

print('Train dataset sizes:', self.train_dataset_sizes)
print('Train class sizes:', self.train_class_sizes)

def preprocess_test(self):
"""preprocess testing data, constructing test loaders
"""
self.test_loaders = {}
self.gallery_meta = {}
self.query_meta = {}

for test_dir in self.datasets:
test_dir = 'data/'+test_dir+'/pytorch'

dataset = test_dir.split('/')[1]
for dataset in self.datasets:
test_dir = os.path.join(self.data_dir, dataset, 'pytorch')
gallery_dataset = datasets.ImageFolder(os.path.join(test_dir, 'gallery'))
query_dataset = datasets.ImageFolder(os.path.join(test_dir, 'query'))

gallery_dataset = ImageDataset(gallery_dataset.imgs, self.data_transforms['val'])
query_dataset = ImageDataset(query_dataset.imgs, self.data_transforms['val'])

self.test_loaders[dataset] = {key: torch.utils.data.DataLoader(
dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=8,
pin_memory=True) for key, dataset in {'gallery': gallery_dataset, 'query': query_dataset}.items()}

dataset,
batch_size=self.batch_size,
shuffle=False,
num_workers=8,
pin_memory=True) for key, dataset in {'gallery': gallery_dataset, 'query': query_dataset}.items()}

gallery_cameras, gallery_labels = get_camera_ids(gallery_dataset.imgs)
self.gallery_meta[dataset] = {
'sizes': len(gallery_dataset),
'sizes': len(gallery_dataset),
'cameras': gallery_cameras,
'labels': gallery_labels
}

query_cameras, query_labels = get_camera_ids(query_dataset.imgs)
self.query_meta[dataset] = {
'sizes': len(query_dataset),
'sizes': len(query_dataset),
'cameras': query_cameras,
'labels': query_labels
}

print('Query Sizes:', self.query_meta[dataset]['sizes'])
print('Gallery Sizes:', self.gallery_meta[dataset]['sizes'])
print('Query Sizes:', self.query_meta[dataset]['sizes'])
print('Gallery Sizes:', self.gallery_meta[dataset]['sizes'])

def preprocess(self):
self.transform()
self.preprocess_train()
self.preprocess_test()
self.preprocess_kd_data('cuhk02')


def get_camera_ids(img_paths):
"""get camera id and labels by image path
"""
camera_ids = []
labels = []
for path, v in img_paths:
filename = os.path.basename(path)
if filename[:3]!='cam':
if filename[:3] != 'cam':
label = filename[0:4]
camera = filename.split('c')[1]
camera = camera.split('s')[0]
else:
label = filename.split('_')[2]
camera = filename.split('_')[1]
if label[0:2]=='-1':
if label[0:2] == '-1':
labels.append(-1)
else:
labels.append(int(label))
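For readers unfamiliar with the filename scheme, here is a minimal, self-contained sketch of what the non-'cam' branch of get_camera_ids extracts. The Market-1501-style name format (<pid>_c<camera>s<sequence>_<frame>_<index>.jpg) is an assumption used only for illustration; this function is not changed by the commit.

import os

def camera_and_label(path):
    # Mirrors the non-'cam' branch of get_camera_ids above (illustrative only).
    filename = os.path.basename(path)
    label = filename[0:4]
    camera = filename.split('c')[1].split('s')[0]
    return camera, (-1 if label[0:2] == '-1' else int(label))

print(camera_and_label('gallery/0005/0005_c3s1_001051_00.jpg'))  # ('3', 5)
print(camera_and_label('gallery/-1/-1_c1s1_000401_03.jpg'))      # ('1', -1), junk/distractor image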
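The substantive change in this file is in preprocess_test: the test directory is now resolved under the data_dir handed to Data, rather than a hard-coded 'data/' prefix. Below is a minimal sketch of the effect, assuming a dataset named 'Market' and the gallery/query layout the loop expects; the data_dir value is hypothetical.

import os

old_test_dir = 'data/' + 'Market' + '/pytorch'              # removed: ignores data_dir
data_dir = '/datasets/reid'                                  # hypothetical location
new_test_dir = os.path.join(data_dir, 'Market', 'pytorch')   # added: follows data_dir

print(old_test_dir)                            # data/Market/pytorch
print(new_test_dir)                            # /datasets/reid/Market/pytorch
print(os.path.join(new_test_dir, 'gallery'))   # /datasets/reid/Market/pytorch/gallery
print(os.path.join(new_test_dir, 'query'))     # /datasets/reid/Market/pytorch/query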
104 changes: 48 additions & 56 deletions main.py
@@ -1,39 +1,33 @@
# -*- coding: utf-8 -*-
from __future__ import print_function, division

import argparse
import torch
import time
import multiprocessing as mp
import os
import yaml
import random
import numpy as np
import scipy.io
import pathlib
import sys
import json
import copy
import multiprocessing as mp
import torch.nn.functional as F

import matplotlib
import torch

matplotlib.use('agg')
import matplotlib.pyplot as plt
from PIL import Image
from client import Client
from server import Server
from utils import set_random_seed
from data_utils import Data

mp.set_start_method('spawn', force=True)
sys.setrecursionlimit(10000)
version = torch.__version__
version = torch.__version__

parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2')
parser.add_argument('--model_name',default='ft_ResNet50', type=str, help='output model name')
parser.add_argument('--project_dir',default='.', type=str, help='project path')
parser.add_argument('--data_dir',default='data',type=str, help='training dir path')
parser.add_argument('--datasets',default='Market,DukeMTMC-reID,cuhk03-np-detected,cuhk01,MSMT17,viper,prid,3dpes,ilids',type=str, help='datasets used')
parser.add_argument('--train_all', action='store_true', help='use all training data' )
parser.add_argument('--gpu_ids', default='0', type=str, help='gpu_ids: e.g. 0 0,1,2 0,2')
parser.add_argument('--model_name', default='ft_ResNet50', type=str, help='output model name')
parser.add_argument('--project_dir', default='.', type=str, help='project path')
parser.add_argument('--data_dir', default='data', type=str, help='training dir path')
parser.add_argument('--datasets',
default='Market,DukeMTMC-reID,cuhk03-np-detected,cuhk01,MSMT17,viper,prid,3dpes,ilids', type=str,
help='datasets used')
parser.add_argument('--train_all', action='store_true', default=True, help='use all training data')
parser.add_argument('--stride', default=2, type=int, help='stride')
parser.add_argument('--lr', default=0.05, type=float, help='learning rate')
parser.add_argument('--drop_rate', default=0.5, type=float, help='drop rate')
@@ -45,56 +39,57 @@

# arguments for data transformation
parser.add_argument('--erasing_p', default=0, type=float, help='Random Erasing probability, in [0,1]')
parser.add_argument('--color_jitter', action='store_true', help='use color jitter in training' )
parser.add_argument('--color_jitter', action='store_true', help='use color jitter in training')

# arguments for testing federated model
parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last')
parser.add_argument('--multi', action='store_true', help='use multiple query' )
parser.add_argument('--multiple_scale',default='1', type=str,help='multiple_scale: e.g. 1 1,1.1 1,1.1,1.2')
parser.add_argument('--test_dir',default='all',type=str, help='./test_data')
parser.add_argument('--which_epoch', default='last', type=str, help='0,1,2,3...or last')
parser.add_argument('--multi', action='store_true', help='use multiple query')
parser.add_argument('--multiple_scale', default='1', type=str, help='multiple_scale: e.g. 1 1,1.1 1,1.1,1.2')

# arguments for optimization
parser.add_argument('--cdw', action='store_true', help='use cosine distance weight for model aggregation, default false' )
parser.add_argument('--kd', action='store_true', help='apply knowledge distillation, default false' )
parser.add_argument('--regularization', action='store_true', help='use regularization during distillation, default false' )
parser.add_argument('--cdw', action='store_true',
help='use cosine distance weight for model aggregation, default false')
parser.add_argument('--kd', action='store_true', help='apply knowledge distillation, default false')
parser.add_argument('--regularization', action='store_true',
help='use regularization during distillation, default false')


def train():
args = parser.parse_args()
print(args)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")

set_random_seed(1)

data = Data(args.datasets, args.data_dir, args.batch_size, args.erasing_p, args.color_jitter, args.train_all)
data.preprocess()

clients = {}
for cid in data.client_list:
clients[cid] = Client(
cid,
data,
device,
args.project_dir,
args.model_name,
args.local_epoch,
args.lr,
args.batch_size,
args.drop_rate,
args.stride)
cid,
data,
device,
args.project_dir,
args.model_name,
args.local_epoch,
args.lr,
args.batch_size,
args.drop_rate,
args.stride)

server = Server(
clients,
data,
device,
args.project_dir,
args.model_name,
args.num_of_clients,
args.lr,
args.drop_rate,
args.stride,
clients,
data,
device,
args.project_dir,
args.model_name,
args.num_of_clients,
args.lr,
args.drop_rate,
args.stride,
args.multiple_scale)

dir_name = os.path.join(args.project_dir, 'model', args.model_name)
@@ -104,22 +99,19 @@ def train():
print("=====training start!========")
rounds = 800
for i in range(rounds):
print('='*10)
print('=' * 10)
print("Round Number {}".format(i))
print('='*10)
print('=' * 10)
server.train(i, args.cdw, use_cuda)
save_path = os.path.join(dir_name, 'federated_model.pth')
torch.save(server.federated_model.cpu().state_dict(), save_path)
if (i+1)%10 == 0:
if (i + 1) % 10 == 0:
server.test(use_cuda)
if args.kd:
server.knowledge_distillation(args.regularization)
server.test(use_cuda)
server.draw_curve()


if __name__ == '__main__':
train()
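One behavioural side effect of the main.py changes, sketched below: --train_all now carries default=True in addition to action='store_true', so the flag is effectively always on and Data always picks the '_all' training split suffix. The snippet rebuilds only that one argument for illustration; it is not part of the commit.

import argparse

parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--train_all', action='store_true', default=True,
                    help='use all training data')

args = parser.parse_args([])                # parse with no CLI flags at all
print(args.train_all)                       # True, the new default added here
print('_all' if args.train_all else '')     # '_all', the suffix Data() will use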



