forked from dk-liang/CLTR
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
109 lines (79 loc) · 3.05 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import logging
import os
import random
import shutil
import cv2
import h5py
import numpy as np
import torch
from mmcv.utils import get_logger
from tqdm import tqdm
from image import load_data, load_data_test
def save_results(input_img, gt_data, density_map, output_dir, fname='results.png'):
density_map[density_map < 0] = 0
gt_data = 255 * gt_data / np.max(gt_data)
gt_data = gt_data[0][0]
gt_data = gt_data.astype(np.uint8)
gt_data = cv2.applyColorMap(gt_data, 2)
density_map = 255 * density_map / np.max(density_map)
density_map = density_map[0][0]
density_map = density_map.astype(np.uint8)
density_map = cv2.applyColorMap(density_map, 2)
result_img = np.hstack((gt_data, density_map))
cv2.imwrite(os.path.join('.', output_dir, fname).replace('.jpg', '.jpg'), result_img)
def save_net(fname, net):
with h5py.File(fname, 'w') as h5f:
for k, v in net.state_dict().items():
h5f.create_dataset(k, data=v.cpu().numpy())
def load_net(fname, net):
with h5py.File(fname, 'r') as h5f:
for k, v in net.state_dict().items():
param = torch.from_numpy(np.asarray(h5f[k]))
v.copy_(param)
def save_checkpoint(state, visi, is_best, save_path, filename='checkpoint.pth'):
torch.save(state, './' + str(save_path) + '/' + filename)
if is_best:
shutil.copyfile('./' + str(save_path) + '/' + filename, './' + str(save_path) + '/' + 'model_best.pth')
for i in range(len(visi)):
img = visi[i][0]
output = visi[i][1]
target = visi[i][2]
fname = visi[i][3]
save_results(img, target, output, str(save_path), fname[0])
def setup_seed(seed):
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False # 输入固定情况下用true
def get_root_logger(log_file=None, log_level=logging.INFO):
"""Get root logger.
Args:
log_file (str, optional): File path of log. Defaults to None.
log_level (int, optional): The level of logger.
Defaults to logging.INFO.
Returns:
:obj:`logging.Logger`: The obtained logger
"""
logger = get_logger(name='CLTR', log_file=log_file, log_level=log_level)
return logger
def pre_data_test(train_list, args, train):
data_keys = {}
count = 0
for j in tqdm(range(len(train_list))):
Img_path = train_list[j]
# Img_path = '../datasets/NWPU_localization/val_data_ori/3234.jpg'
fname = os.path.basename(Img_path)
# print(fname)
img = load_data_test(Img_path, args, train)
# print(img.size, fidt_map.shape)
blob = {}
blob['img'] = img
blob['fname'] = fname
data_keys[count] = blob
count += 1
# if count >10:
# break
return data_keys