diff --git a/data_utils.py b/data_utils.py index 0aa3d65..d454670 100644 --- a/data_utils.py +++ b/data_utils.py @@ -1,13 +1,15 @@ -from torch.utils.data import Dataset -from PIL import Image -from torchvision import datasets, transforms import os -import json + import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision import datasets, transforms + from random_erasing import RandomErasing + class ImageDataset(Dataset): - def __init__(self, imgs, transform = None): + def __init__(self, imgs, transform=None): self.imgs = imgs self.transform = transform @@ -15,11 +17,11 @@ def __len__(self): return len(self.imgs) def __getitem__(self, index): - data,label = self.imgs[index] + data, label = self.imgs[index] return self.transform(Image.open(data)), label -class Data(): +class Data: def __init__(self, datasets, data_dir, batch_size, erasing_p, color_jitter, train_all): self.datasets = datasets.split(',') self.batch_size = batch_size @@ -27,39 +29,48 @@ def __init__(self, datasets, data_dir, batch_size, erasing_p, color_jitter, trai self.color_jitter = color_jitter self.data_dir = data_dir self.train_all = '_all' if train_all else '' - + self.data_transforms = {} + self.train_loaders = {} + self.train_dataset_sizes = {} + self.train_class_sizes = {} + self.client_list = [] + self.test_loaders = {} + self.gallery_meta = {} + self.query_meta = {} + self.kd_loader = None + def transform(self): transform_train = [ - transforms.Resize((256,128), interpolation=3), - transforms.Pad(10), - transforms.RandomCrop((256,128)), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ] + transforms.Resize((256, 128), interpolation=3), + transforms.Pad(10), + transforms.RandomCrop((256, 128)), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ] transform_val = [ - transforms.Resize(size=(256,128),interpolation=3), - transforms.ToTensor(), - transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) - ] + transforms.Resize(size=(256, 128), interpolation=3), + transforms.ToTensor(), + transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) + ] if self.erasing_p > 0: transform_train = transform_train + [RandomErasing(probability=self.erasing_p, mean=[0.0, 0.0, 0.0])] if self.color_jitter: - transform_train = [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)] + transform_train + transform_train = [transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0)] \ + + transform_train self.data_transforms = { 'train': transforms.Compose(transform_train), 'val': transforms.Compose(transform_val), - } + } def preprocess_kd_data(self, dataset): loader, image_dataset = self.preprocess_one_train_dataset(dataset) self.kd_loader = loader - def preprocess_one_train_dataset(self, dataset): """preprocess a training dataset, construct a data loader. """ @@ -68,10 +79,10 @@ def preprocess_one_train_dataset(self, dataset): image_dataset = datasets.ImageFolder(data_path) loader = torch.utils.data.DataLoader( - ImageDataset(image_dataset.imgs, self.data_transforms['train']), + ImageDataset(image_dataset.imgs, self.data_transforms['train']), batch_size=self.batch_size, - shuffle=True, - num_workers=2, + shuffle=True, + num_workers=2, pin_memory=False) return loader, image_dataset @@ -79,64 +90,52 @@ def preprocess_one_train_dataset(self, dataset): def preprocess_train(self): """preprocess training data, constructing train loaders """ - self.train_loaders = {} - self.train_dataset_sizes = {} - self.train_class_sizes = {} - self.client_list = [] - for dataset in self.datasets: self.client_list.append(dataset) - + loader, image_dataset = self.preprocess_one_train_dataset(dataset) self.train_dataset_sizes[dataset] = len(image_dataset) self.train_class_sizes[dataset] = len(image_dataset.classes) self.train_loaders[dataset] = loader - + print('Train dataset sizes:', self.train_dataset_sizes) print('Train class sizes:', self.train_class_sizes) - + def preprocess_test(self): """preprocess testing data, constructing test loaders """ - self.test_loaders = {} - self.gallery_meta = {} - self.query_meta = {} - - for test_dir in self.datasets: - test_dir = 'data/'+test_dir+'/pytorch' - - dataset = test_dir.split('/')[1] + for dataset in self.datasets: + test_dir = os.path.join(self.data_dir, dataset, 'pytorch') gallery_dataset = datasets.ImageFolder(os.path.join(test_dir, 'gallery')) query_dataset = datasets.ImageFolder(os.path.join(test_dir, 'query')) - + gallery_dataset = ImageDataset(gallery_dataset.imgs, self.data_transforms['val']) query_dataset = ImageDataset(query_dataset.imgs, self.data_transforms['val']) self.test_loaders[dataset] = {key: torch.utils.data.DataLoader( - dataset, - batch_size=self.batch_size, - shuffle=False, - num_workers=8, - pin_memory=True) for key, dataset in {'gallery': gallery_dataset, 'query': query_dataset}.items()} - + dataset, + batch_size=self.batch_size, + shuffle=False, + num_workers=8, + pin_memory=True) for key, dataset in {'gallery': gallery_dataset, 'query': query_dataset}.items()} gallery_cameras, gallery_labels = get_camera_ids(gallery_dataset.imgs) self.gallery_meta[dataset] = { - 'sizes': len(gallery_dataset), + 'sizes': len(gallery_dataset), 'cameras': gallery_cameras, 'labels': gallery_labels } query_cameras, query_labels = get_camera_ids(query_dataset.imgs) self.query_meta[dataset] = { - 'sizes': len(query_dataset), + 'sizes': len(query_dataset), 'cameras': query_cameras, 'labels': query_labels } - print('Query Sizes:', self.query_meta[dataset]['sizes']) - print('Gallery Sizes:', self.gallery_meta[dataset]['sizes']) + print('Query Sizes:', self.query_meta[dataset]['sizes']) + print('Gallery Sizes:', self.gallery_meta[dataset]['sizes']) def preprocess(self): self.transform() @@ -144,6 +143,7 @@ def preprocess(self): self.preprocess_test() self.preprocess_kd_data('cuhk02') + def get_camera_ids(img_paths): """get camera id and labels by image path """ @@ -151,14 +151,14 @@ def get_camera_ids(img_paths): labels = [] for path, v in img_paths: filename = os.path.basename(path) - if filename[:3]!='cam': + if filename[:3] != 'cam': label = filename[0:4] camera = filename.split('c')[1] camera = camera.split('s')[0] else: label = filename.split('_')[2] camera = filename.split('_')[1] - if label[0:2]=='-1': + if label[0:2] == '-1': labels.append(-1) else: labels.append(int(label)) diff --git a/main.py b/main.py index f68f142..e5719ff 100644 --- a/main.py +++ b/main.py @@ -1,23 +1,15 @@ # -*- coding: utf-8 -*- from __future__ import print_function, division + import argparse -import torch -import time +import multiprocessing as mp import os -import yaml -import random -import numpy as np -import scipy.io -import pathlib import sys -import json -import copy -import multiprocessing as mp -import torch.nn.functional as F + import matplotlib +import torch + matplotlib.use('agg') -import matplotlib.pyplot as plt -from PIL import Image from client import Client from server import Server from utils import set_random_seed @@ -25,15 +17,17 @@ mp.set_start_method('spawn', force=True) sys.setrecursionlimit(10000) -version = torch.__version__ +version = torch.__version__ parser = argparse.ArgumentParser(description='Training') -parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2') -parser.add_argument('--model_name',default='ft_ResNet50', type=str, help='output model name') -parser.add_argument('--project_dir',default='.', type=str, help='project path') -parser.add_argument('--data_dir',default='data',type=str, help='training dir path') -parser.add_argument('--datasets',default='Market,DukeMTMC-reID,cuhk03-np-detected,cuhk01,MSMT17,viper,prid,3dpes,ilids',type=str, help='datasets used') -parser.add_argument('--train_all', action='store_true', help='use all training data' ) +parser.add_argument('--gpu_ids', default='0', type=str, help='gpu_ids: e.g. 0 0,1,2 0,2') +parser.add_argument('--model_name', default='ft_ResNet50', type=str, help='output model name') +parser.add_argument('--project_dir', default='.', type=str, help='project path') +parser.add_argument('--data_dir', default='data', type=str, help='training dir path') +parser.add_argument('--datasets', + default='Market,DukeMTMC-reID,cuhk03-np-detected,cuhk01,MSMT17,viper,prid,3dpes,ilids', type=str, + help='datasets used') +parser.add_argument('--train_all', action='store_true', default=True, help='use all training data') parser.add_argument('--stride', default=2, type=int, help='stride') parser.add_argument('--lr', default=0.05, type=float, help='learning rate') parser.add_argument('--drop_rate', default=0.5, type=float, help='drop rate') @@ -45,24 +39,25 @@ # arguments for data transformation parser.add_argument('--erasing_p', default=0, type=float, help='Random Erasing probability, in [0,1]') -parser.add_argument('--color_jitter', action='store_true', help='use color jitter in training' ) +parser.add_argument('--color_jitter', action='store_true', help='use color jitter in training') # arguments for testing federated model -parser.add_argument('--which_epoch',default='last', type=str, help='0,1,2,3...or last') -parser.add_argument('--multi', action='store_true', help='use multiple query' ) -parser.add_argument('--multiple_scale',default='1', type=str,help='multiple_scale: e.g. 1 1,1.1 1,1.1,1.2') -parser.add_argument('--test_dir',default='all',type=str, help='./test_data') +parser.add_argument('--which_epoch', default='last', type=str, help='0,1,2,3...or last') +parser.add_argument('--multi', action='store_true', help='use multiple query') +parser.add_argument('--multiple_scale', default='1', type=str, help='multiple_scale: e.g. 1 1,1.1 1,1.1,1.2') # arguments for optimization -parser.add_argument('--cdw', action='store_true', help='use cosine distance weight for model aggregation, default false' ) -parser.add_argument('--kd', action='store_true', help='apply knowledge distillation, default false' ) -parser.add_argument('--regularization', action='store_true', help='use regularization during distillation, default false' ) +parser.add_argument('--cdw', action='store_true', + help='use cosine distance weight for model aggregation, default false') +parser.add_argument('--kd', action='store_true', help='apply knowledge distillation, default false') +parser.add_argument('--regularization', action='store_true', + help='use regularization during distillation, default false') def train(): args = parser.parse_args() print(args) - + use_cuda = torch.cuda.is_available() device = torch.device("cuda" if use_cuda else "cpu") @@ -70,31 +65,31 @@ def train(): data = Data(args.datasets, args.data_dir, args.batch_size, args.erasing_p, args.color_jitter, args.train_all) data.preprocess() - + clients = {} for cid in data.client_list: clients[cid] = Client( - cid, - data, - device, - args.project_dir, - args.model_name, - args.local_epoch, - args.lr, - args.batch_size, - args.drop_rate, - args.stride) + cid, + data, + device, + args.project_dir, + args.model_name, + args.local_epoch, + args.lr, + args.batch_size, + args.drop_rate, + args.stride) server = Server( - clients, - data, - device, - args.project_dir, - args.model_name, - args.num_of_clients, - args.lr, - args.drop_rate, - args.stride, + clients, + data, + device, + args.project_dir, + args.model_name, + args.num_of_clients, + args.lr, + args.drop_rate, + args.stride, args.multiple_scale) dir_name = os.path.join(args.project_dir, 'model', args.model_name) @@ -104,22 +99,19 @@ def train(): print("=====training start!========") rounds = 800 for i in range(rounds): - print('='*10) + print('=' * 10) print("Round Number {}".format(i)) - print('='*10) + print('=' * 10) server.train(i, args.cdw, use_cuda) save_path = os.path.join(dir_name, 'federated_model.pth') torch.save(server.federated_model.cpu().state_dict(), save_path) - if (i+1)%10 == 0: + if (i + 1) % 10 == 0: server.test(use_cuda) if args.kd: server.knowledge_distillation(args.regularization) server.test(use_cuda) server.draw_curve() + if __name__ == '__main__': train() - - - -