Skip to content

Commit 4725c6b

Browse files
committedJun 26, 2014
Added SpatialUpSamplingNearest module.
1 parent 1310a04 commit 4725c6b

File tree

6 files changed

+308
-0
lines changed

6 files changed

+308
-0
lines changed
 

‎SpatialUpSamplingNearest.lua

+58
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
local SpatialUpSamplingNearest, parent = torch.class('nn.SpatialUpSamplingNearest', 'nn.Module')
2+
3+
--[[
4+
Applies a 2D up-sampling over an input image composed of several input planes.
5+
6+
The upsampling is done using the simple nearest neighbor technique.
7+
8+
The Y and X dimensions are assumed to be the last 2 tensor dimensions. For
9+
instance, if the tensor is 4D, then dim 3 is the y dimension and dim 4 is the x.
10+
11+
owidth = width*scale_factor
12+
oheight = height*scale_factor
13+
--]]
14+
15+
function SpatialUpSamplingNearest:__init(scale)
16+
parent.__init(self)
17+
18+
self.scale_factor = scale
19+
if self.scale_factor < 1 then
20+
error('scale_factor must be greater than 1')
21+
end
22+
if math.floor(self.scale_factor) ~= self.scale_factor then
23+
error('scale_factor must be integer')
24+
end
25+
self.inputSize = torch.LongStorage(4)
26+
self.outputSize = torch.LongStorage(4)
27+
self.usage = nil
28+
end
29+
30+
function SpatialUpSamplingNearest:updateOutput(input)
31+
if input:dim() ~= 4 and input:dim() ~= 3 then
32+
error('SpatialUpSamplingNearest only support 3D or 4D tensors')
33+
end
34+
-- Copy the input size
35+
local xdim = input:dim()
36+
local ydim = input:dim() - 1
37+
for i = 1, input:dim() do
38+
self.inputSize[i] = input:size(i)
39+
self.outputSize[i] = input:size(i)
40+
end
41+
self.outputSize[ydim] = self.outputSize[ydim] * self.scale_factor
42+
self.outputSize[xdim] = self.outputSize[xdim] * self.scale_factor
43+
-- Resize the output if needed
44+
if input:dim() == 3 then
45+
self.output:resize(self.outputSize[1], self.outputSize[2],
46+
self.outputSize[3])
47+
else
48+
self.output:resize(self.outputSize)
49+
end
50+
input.nn.SpatialUpSamplingNearest_updateOutput(self, input)
51+
return self.output
52+
end
53+
54+
function SpatialUpSamplingNearest:updateGradInput(input, gradOutput)
55+
self.gradInput:resizeAs(input)
56+
input.nn.SpatialUpSamplingNearest_updateGradInput(self, input, gradOutput)
57+
return self.gradInput
58+
end

‎doc/convolution.md

100644100755
+21
Original file line numberDiff line numberDiff line change
@@ -396,6 +396,27 @@ output[i][j][k] = bias[k]
396396
+ weight[k] sum_{s=1}^kW sum_{t=1}^kH input[dW*(i-1)+s)][dH*(j-1)+t][k]
397397
```
398398

399+
<a name="nn.SpatialUpSamplingNearest"/>
400+
### SpatialUpSamplingNearest ###
401+
402+
```lua
403+
module = nn.SpatialUpSamplingNearest(scale)
404+
```
405+
406+
Applies a 2D up-sampling over an input image composed of several input planes. The `input` tensor in
407+
`forward(input)` is expected to be a 3D or 4D tensor (i.e. for 4D: `nBatchPlane x nInputPlane x height x width`). The number of output planes will be the same. The v dimension is assumed to be the second last dimension (i.e. for 4D it will be the 3rd dim), and the u dimension is assumed to be the last dimension.
408+
409+
The parameters are the following:
410+
* `scale`: The upscale ratio. Must be a positive integer
411+
412+
The up-scaling method is simple nearest neighbor, ie:
413+
414+
```lua
415+
output(u,v) = input(floor((u-1)/scale)+1, floor((v-1)/scale)+1)
416+
```
417+
418+
Where `u` and `v` are index from 1 (as per lua convention). There are no learnable parameters.
419+
399420
<a name="nn.SpatialZeroPadding"/>
400421
### SpatialZeroPadding ###
401422

‎generic/SpatialUpSamplingNearest.c

+159
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
#ifndef TH_GENERIC_FILE
2+
#define TH_GENERIC_FILE "generic/SpatialUpSamplingNearest.c"
3+
#else
4+
5+
static int nn_(SpatialUpSamplingNearest_updateOutput)(lua_State *L)
6+
{
7+
// get all params
8+
THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
9+
int scale_factor = luaT_getfieldcheckint(L, 1, "scale_factor");
10+
int dW = scale_factor;
11+
int dH = scale_factor;
12+
int xDim = input->nDimension-2;
13+
int yDim = input->nDimension-1;
14+
THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
15+
16+
// dims
17+
int idim = input->nDimension; // Gauranteed to be between 3 and 5
18+
int osz0 = output->size[0];
19+
int osz1 = output->size[1];
20+
int osz2 = output->size[2];
21+
int osz3 = 1;
22+
if (idim > 3) {
23+
osz3 = output->size[3];
24+
}
25+
26+
// get strides
27+
long *is = input->stride;
28+
long *os = output->stride;
29+
30+
// get raw pointers
31+
real *pin = THTensor_(data)(input);
32+
real *pout = THTensor_(data)(output);
33+
34+
// perform the upsampling
35+
int i0, i1, i2, i3, isrc, idst;
36+
int iout[4]; // Output indices
37+
int iin[4]; // Input indices
38+
39+
for (i0 = 0; i0 < osz0; i0++) {
40+
iout[0] = i0;
41+
iin[0] = i0;
42+
for (i1 = 0; i1 < osz1; i1++) {
43+
iout[1] = i1;
44+
iin[1] = i1;
45+
for (i2 = 0; i2 < osz2; i2++) {
46+
iout[2] = i2;
47+
iin[2] = i2;
48+
for (i3 = 0; i3 < osz3; i3++) {
49+
iout[3] = i3;
50+
iin[3] = i3;
51+
52+
// set the indices for the upsampled dimensions
53+
iin[xDim] = iout[xDim] / dW;
54+
iin[yDim] = iout[yDim] / dH;
55+
56+
idst = i0*os[0] + i1*os[1] + i2*os[2];
57+
isrc = iin[0]*is[0] + iin[1]*is[1] + iin[2]*is[2];
58+
if (idim > 3) {
59+
idst += i3*os[3];
60+
isrc += iin[3]*is[3];
61+
}
62+
63+
pout[idst] = pin[isrc];
64+
}
65+
}
66+
}
67+
}
68+
return 1;
69+
}
70+
71+
static int nn_(SpatialUpSamplingNearest_updateGradInput)(lua_State *L)
72+
{
73+
// get all params
74+
//THTensor *input = luaT_checkudata(L,2, torch_Tensor);
75+
THTensor *gradOutput = luaT_checkudata(L,3, torch_Tensor);
76+
THTensor *gradInput = luaT_getfieldcheckudata(L,1, "gradInput", torch_Tensor);
77+
78+
int scale_factor = luaT_getfieldcheckint(L, 1, "scale_factor");
79+
int dW = scale_factor;
80+
int dH = scale_factor;
81+
int xDim = gradInput->nDimension-2;
82+
int yDim = gradInput->nDimension-1;
83+
84+
// dims
85+
int idim = gradInput->nDimension; // Gauranteed to be between 3 and 5
86+
int isz0 = gradInput->size[0];
87+
int isz1 = gradInput->size[1];
88+
int isz2 = gradInput->size[2];
89+
int isz3 = 1;
90+
if (idim > 3) {
91+
isz3 = gradInput->size[3];
92+
}
93+
94+
// get strides
95+
long *is = gradInput->stride;
96+
long *os = gradOutput->stride;
97+
98+
// get raw pointers
99+
real *pin = THTensor_(data)(gradInput);
100+
real *pout = THTensor_(data)(gradOutput);
101+
102+
// perform the upsampling
103+
int i0, i1, i2, i3, isrc, idst, x, y;
104+
int iin[4]; // Input indices
105+
int iout[4]; // Output indices
106+
107+
THTensor_(zero)(gradInput);
108+
109+
for (i0 = 0; i0 < isz0; i0++) {
110+
iin[0] = i0;
111+
iout[0] = i0;
112+
for (i1 = 0; i1 < isz1; i1++) {
113+
iin[1] = i1;
114+
iout[1] = i1;
115+
for (i2 = 0; i2 < isz2; i2++) {
116+
iin[2] = i2;
117+
iout[2] = i2;
118+
for (i3 = 0; i3 < isz3; i3++) {
119+
iin[3] = i3;
120+
iout[3] = i3;
121+
122+
idst = i0*is[0] + i1*is[1] + i2*is[2];
123+
if (idim > 3) {
124+
idst += i3*is[3];
125+
}
126+
127+
// Now accumulate the gradients from gradOutput
128+
for (y = 0; y < dH; y++) {
129+
for (x = 0; x < dW; x++) {
130+
iout[xDim] = dW * iin[xDim] + x;
131+
iout[yDim] = dH * iin[yDim] + y;
132+
isrc = iout[0]*os[0] + iout[1]*os[1] + iout[2]*os[2];
133+
if (idim > 3) {
134+
isrc += iout[3]*os[3];
135+
}
136+
pin[idst] += pout[isrc];
137+
}
138+
}
139+
}
140+
}
141+
}
142+
}
143+
return 1;
144+
}
145+
146+
static const struct luaL_Reg nn_(SpatialUpSamplingNearest__) [] = {
147+
{"SpatialUpSamplingNearest_updateOutput", nn_(SpatialUpSamplingNearest_updateOutput)},
148+
{"SpatialUpSamplingNearest_updateGradInput", nn_(SpatialUpSamplingNearest_updateGradInput)},
149+
{NULL, NULL}
150+
};
151+
152+
static void nn_(SpatialUpSamplingNearest_init)(lua_State *L)
153+
{
154+
luaT_pushmetatable(L, torch_Tensor);
155+
luaT_registeratname(L, nn_(SpatialUpSamplingNearest__), "nn");
156+
lua_pop(L,1);
157+
}
158+
159+
#endif

‎init.c

+5
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,9 @@
107107
#include "generic/L1Cost.c"
108108
#include "THGenerateFloatTypes.h"
109109

110+
#include "generic/SpatialUpSamplingNearest.c"
111+
#include "THGenerateFloatTypes.h"
112+
110113
LUA_EXTERNC DLL_EXPORT int luaopen_libnn(lua_State *L);
111114

112115
int luaopen_libnn(lua_State *L)
@@ -149,6 +152,7 @@ int luaopen_libnn(lua_State *L)
149152
nn_FloatMultiMarginCriterion_init(L);
150153
nn_FloatMultiLabelMarginCriterion_init(L);
151154
nn_FloatL1Cost_init(L);
155+
nn_FloatSpatialUpSamplingNearest_init(L);
152156

153157
nn_DoubleMin_init(L);
154158
nn_DoubleMax_init(L);
@@ -184,6 +188,7 @@ int luaopen_libnn(lua_State *L)
184188
nn_DoubleMultiMarginCriterion_init(L);
185189
nn_DoubleMultiLabelMarginCriterion_init(L);
186190
nn_DoubleL1Cost_init(L);
191+
nn_DoubleSpatialUpSamplingNearest_init(L);
187192

188193
return 1;
189194
}

‎init.lua

+1
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ include('SpatialSubtractiveNormalization.lua')
7676
include('SpatialDivisiveNormalization.lua')
7777
include('SpatialContrastiveNormalization.lua')
7878
include('SpatialZeroPadding.lua')
79+
include('SpatialUpSamplingNearest.lua')
7980

8081
include('VolumetricConvolution.lua')
8182
include('VolumetricMaxPooling.lua')

‎test/test.lua

+64
Original file line numberDiff line numberDiff line change
@@ -1886,6 +1886,70 @@ function nntest.View()
18861886
"Error in minibatch nElement")
18871887
end
18881888

1889+
-- Define a test for SpatialUpSamplingCuda
1890+
function nntest.SpatialUpSamplingNearest()
1891+
local scale = torch.random(2,4)
1892+
for dim = 3,4 do
1893+
local m = nn.SpatialUpSamplingNearest(scale)
1894+
1895+
-- Create a randomly sized dimD vector
1896+
local shape = {}
1897+
for i = 1, dim do
1898+
table.insert(shape, torch.random(2, 2+dim-1))
1899+
end
1900+
1901+
-- Check that the gradient is correct by using finite elements
1902+
local input = torch.Tensor(unpack(shape)):zero()
1903+
1904+
local err = jac.testJacobian(m, input)
1905+
mytester:assertlt(err, precision, ' error on state ')
1906+
1907+
local ferr, berr = jac.testIO(m, input)
1908+
mytester:asserteq(ferr, 0, torch.typename(m)..' - i/o forward err ')
1909+
mytester:asserteq(berr, 0, torch.typename(m)..' - i/o backward err ')
1910+
1911+
-- Also check that the forward prop is correct.
1912+
input = torch.rand(unpack(shape))
1913+
local output = m:forward(input)
1914+
1915+
local feat
1916+
local nfeats
1917+
if input:dim() == 3 then
1918+
nfeats = shape[1]
1919+
feat = {0}
1920+
else
1921+
feat = {0, 0}
1922+
nfeats = shape[1] * shape[2]
1923+
end
1924+
feat[#feat+1] = 0 -- ydim
1925+
feat[#feat+1] = 0 -- xdim
1926+
local xdim = input:dim()
1927+
local ydim = input:dim()-1
1928+
local err = 0
1929+
for f = 1, nfeats do
1930+
if input:dim() == 4 then
1931+
feat[1] = math.floor((f-1) / shape[1]) + 1
1932+
feat[2] = math.mod((f-1), shape[2]) + 1
1933+
else
1934+
feat[1] = f
1935+
end
1936+
for y = 1, input:size(ydim) * scale do
1937+
for x = 1, input:size(xdim) * scale do
1938+
feat[ydim] = y
1939+
feat[xdim] = x
1940+
local oval = output[feat]
1941+
feat[ydim] = math.floor((y-1)/scale)+1
1942+
feat[xdim] = math.floor((x-1)/scale)+1
1943+
local ival = input[feat]
1944+
err = math.max(err, math.abs(oval-ival))
1945+
end
1946+
end
1947+
end
1948+
1949+
mytester:assertlt(err, precision, ' fprop is incorrect ')
1950+
end
1951+
end
1952+
18891953
mytester:add(nntest)
18901954

18911955
if not nn then

0 commit comments

Comments
 (0)
Please sign in to comment.