forked from yoyohonyang/LearingFaceAgeProgression
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_aging.lua
executable file
·81 lines (71 loc) · 2.22 KB
/
test_aging.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
require 'torch'
require 'image'
require 'nn'
require 'cutorch'
require 'FaceAging.DataLoader_aging_test'
require 'FaceAging.ShaveImage'
require 'FaceAging.TotalVariation'
require 'FaceAging.InstanceNormalization'
local utils = require 'FaceAging.utils'
local preprocess = require 'FaceAging.preprocess'
local cmd = torch.CmdLine()
-- Model options
cmd:option('-model', '.models/CACD_Aging.t7')
cmd:option('-h5_file', 'data/CACD/CACD_test.h5')
cmd:option('-use_instance_norm', 1)
cmd:option('-preprocessing', 'vgg')
cmd:option('-gpu', 1)
cmd:option('-backend', 'cuda')
cmd:option('-output_dir', './CACD_aging_results')
cmd:option('-output_filename', 'aging_results')
cmd:option('-num_val_batches', 25)
cmd:option('-batch_size', 8)
local function main()
local opt = cmd:parse(arg)
paths.mkdir(opt.output_dir)
-- Figure out preprocessing
if not preprocess[opt.preprocessing] then
local msg = 'invalid -preprocessing "%s"; must be "vgg" or "resnet"'
error(string.format(msg, opt.preprocessing))
end
preprocess = preprocess[opt.preprocessing]
local loader = DataLoader_aging_test(opt)
local dtype, use_cudnn = utils.setup_gpu(opt.gpu, opt.backend, opt.use_cudnn == 1)
cutorch.setDevice(1)
-- Set up the generator
local netG = nil
if opt.model ~= '' then
print('Loading model from ' .. opt.model)
netG = torch.load(opt.model).netG:type(dtype)
else
print('No such file!')
end
loader:reset('val')
--- Warm up the GPU
if opt.use_cudnn then
print('Warm up the GPU...')
for t1 = 1, 10 do
local x1 = torch.randn(opt.batch_size, 3, 224, 224):type(dtype)
local y1 = netG:forward(x1)
end
end
--- Testing
netG:evaluate()
print('Testing...')
for t2 = 1, opt.num_val_batches do
print(t2)
image_out_final = nil
local x = loader:getBatch('val')
x = x:type(dtype)
local out = netG:forward(x)
out = preprocess.deprocess(out)
input = preprocess.deprocess(x)
image_out = nil
for i2=1, out:size(1) do
if image_out==nil then image_out = torch.cat(input[i2]:float(),out[i2]:float(),3)
else image_out = torch.cat(image_out, torch.cat(input[i2]:float(),out[i2]:float(),3), 2) end
end
image.save(paths.concat(opt.output_dir, opt.output_filename ..'_'.. t2 .. '_test_res.png'), image_out)
end
end
main()