Skip to content

Commit 564d372

Browse files
committed Nov 17, 2014
fix 1D case for Mean/Sum/Max/Min
1 parent bc684d8 commit 564d372

File tree

5 files changed

+107
-39
lines changed

5 files changed

+107
-39
lines changed
 

‎Mean.lua

+11-3
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,23 @@ end
88

99
function Mean:updateOutput(input)
1010
self.output:mean(input, self.dimension)
11-
self.output = self.output:select(self.dimension, 1)
11+
if self.output:nDimension() > 1 then
12+
self.output = self.output:select(self.dimension, 1)
13+
end
1214
return self.output
1315
end
1416

1517
function Mean:updateGradInput(input, gradOutput)
1618
local size = gradOutput:size():totable()
1719
local stride = gradOutput:stride():totable()
18-
table.insert(size, self.dimension, input:size(self.dimension))
19-
table.insert(stride, self.dimension, 0)
20+
21+
if input:nDimension() > 1 then
22+
table.insert(size, self.dimension, input:size(self.dimension))
23+
table.insert(stride, self.dimension, 0)
24+
else
25+
size[1] = input:size(1)
26+
stride[1] = 0
27+
end
2028

2129
self.gradInput:resizeAs(gradOutput):copy(gradOutput)
2230
self.gradInput:mul(1/input:size(self.dimension))

‎Sum.lua

+12-4
Original file line numberDiff line numberDiff line change
@@ -11,20 +11,28 @@ function Sum:updateOutput(input)
1111
self.output = input.new()
1212
end
1313
self.output:sum(input, self.dimension)
14-
self.output = self.output:select(self.dimension, 1)
14+
if self.output:nDimension() > 1 then
15+
self.output = self.output:select(self.dimension, 1)
16+
end
1517
return self.output
1618
end
1719

1820
function Sum:updateGradInput(input, gradOutput)
1921
local size = gradOutput:size():totable()
2022
local stride = gradOutput:stride():totable()
21-
table.insert(size, self.dimension, input:size(self.dimension))
22-
table.insert(stride, self.dimension, 0)
23+
24+
if input:nDimension() > 1 then
25+
table.insert(size, self.dimension, input:size(self.dimension))
26+
table.insert(stride, self.dimension, 0)
27+
else
28+
size[1] = input:size(1)
29+
stride[1] = 0
30+
end
2331

2432
self.gradInput:set(gradOutput:storage(),
2533
1,
2634
torch.LongStorage(size),
2735
torch.LongStorage(stride))
28-
36+
2937
return self.gradInput
3038
end

‎generic/Max.c

+24-16
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ static int nn_(Max_updateOutput)(lua_State *L)
3636
*indices_data = theIndex+1;
3737
*output_data = theMax;)
3838

39-
THTensor_(select)(output, NULL, dimension, 0);
39+
if(output->nDimension > 1)
40+
THTensor_(select)(output, NULL, dimension, 0);
4041

4142
return 1;
4243
}
@@ -56,25 +57,32 @@ static int nn_(Max_updateGradInput)(lua_State *L)
5657
THTensor_(resizeAs)(gradInput, input);
5758
THTensor_(zero)(gradInput);
5859

59-
dim = THLongStorage_newWithSize(gradOutput->nDimension+1);
60-
str = THLongStorage_newWithSize(gradOutput->nDimension+1);
61-
for(i = 0, j = 0; j < gradOutput->nDimension+1; j++)
60+
if(input->nDimension > 1)
6261
{
63-
if(j == dimension)
62+
dim = THLongStorage_newWithSize(gradOutput->nDimension+1);
63+
str = THLongStorage_newWithSize(gradOutput->nDimension+1);
64+
for(i = 0, j = 0; j < gradOutput->nDimension+1; j++)
6465
{
65-
dim->data[j] = input->size[dimension];
66-
str->data[j] = 0;
67-
continue;
66+
if(j == dimension)
67+
{
68+
dim->data[j] = input->size[dimension];
69+
str->data[j] = 0;
70+
continue;
71+
}
72+
73+
dim->data[j] = gradOutput->size[i];
74+
str->data[j] = gradOutput->stride[i];
75+
i++;
6876
}
69-
70-
dim->data[j] = gradOutput->size[i];
71-
str->data[j] = gradOutput->stride[i];
72-
i++;
77+
gradOutputPlusOneDim = THTensor_(newWithStorage)(gradOutput->storage, gradOutput->storageOffset, dim, str);
78+
THLongStorage_free(dim);
79+
THLongStorage_free(str);
80+
}
81+
else
82+
{
83+
THTensor_(retain)(gradOutput);
84+
gradOutputPlusOneDim = gradOutput;
7385
}
74-
75-
gradOutputPlusOneDim = THTensor_(newWithStorage)(gradOutput->storage, gradOutput->storageOffset, dim, str);
76-
THLongStorage_free(dim);
77-
THLongStorage_free(str);
7886

7987
TH_TENSOR_DIM_APPLY3(real, gradInput, real, gradOutputPlusOneDim, real, indices, dimension,
8088
gradInput_data[ ((long)(*indices_data)-1)*gradInput_stride ] = *gradOutputPlusOneDim_data;)

‎generic/Min.c

+24-16
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@ static int nn_(Min_updateOutput)(lua_State *L)
3636
*indices_data = theIndex+1;
3737
*output_data = theMin;)
3838

39-
THTensor_(select)(output, NULL, dimension, 0);
39+
if(output->nDimension > 1)
40+
THTensor_(select)(output, NULL, dimension, 0);
4041

4142
return 1;
4243
}
@@ -56,25 +57,32 @@ static int nn_(Min_updateGradInput)(lua_State *L)
5657
THTensor_(resizeAs)(gradInput, input);
5758
THTensor_(zero)(gradInput);
5859

59-
dim = THLongStorage_newWithSize(gradOutput->nDimension+1);
60-
str = THLongStorage_newWithSize(gradOutput->nDimension+1);
61-
for(i = 0, j = 0; j < gradOutput->nDimension+1; j++)
60+
if(input->nDimension > 1)
6261
{
63-
if(j == dimension)
62+
dim = THLongStorage_newWithSize(gradOutput->nDimension+1);
63+
str = THLongStorage_newWithSize(gradOutput->nDimension+1);
64+
for(i = 0, j = 0; j < gradOutput->nDimension+1; j++)
6465
{
65-
dim->data[j] = input->size[dimension];
66-
str->data[j] = 0;
67-
continue;
66+
if(j == dimension)
67+
{
68+
dim->data[j] = input->size[dimension];
69+
str->data[j] = 0;
70+
continue;
71+
}
72+
73+
dim->data[j] = gradOutput->size[i];
74+
str->data[j] = gradOutput->stride[i];
75+
i++;
6876
}
69-
70-
dim->data[j] = gradOutput->size[i];
71-
str->data[j] = gradOutput->stride[i];
72-
i++;
77+
gradOutputPlusOneDim = THTensor_(newWithStorage)(gradOutput->storage, gradOutput->storageOffset, dim, str);
78+
THLongStorage_free(dim);
79+
THLongStorage_free(str);
80+
}
81+
else
82+
{
83+
THTensor_(retain)(gradOutput);
84+
gradOutputPlusOneDim = gradOutput;
7385
}
74-
75-
gradOutputPlusOneDim = THTensor_(newWithStorage)(gradOutput->storage, gradOutput->storageOffset, dim, str);
76-
THLongStorage_free(dim);
77-
THLongStorage_free(str);
7886

7987
TH_TENSOR_DIM_APPLY3(real, gradInput, real, gradOutputPlusOneDim, real, indices, dimension,
8088
gradInput_data[ ((long)(*indices_data)-1)*gradInput_stride ] = *gradOutputPlusOneDim_data;)

‎test/test.lua

+36
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,15 @@ end
595595
-- end
596596

597597
function nntest.Max()
598+
-- 1D
599+
local ini = math.random(3,7)
600+
local input = torch.Tensor(ini):zero()
601+
local module = nn.Max(1)
602+
603+
local err = jac.testJacobian(module,input)
604+
mytester:assertlt(err,precision, 'error on state ')
605+
606+
-- 3D
598607
local ini = math.random(3,5)
599608
local inj = math.random(3,5)
600609
local ink = math.random(3,5)
@@ -610,6 +619,15 @@ function nntest.Max()
610619
end
611620

612621
function nntest.Min()
622+
-- 1D
623+
local ini = math.random(3,7)
624+
local input = torch.Tensor(ini):zero()
625+
local module = nn.Min(1)
626+
627+
local err = jac.testJacobian(module,input)
628+
mytester:assertlt(err,precision, 'error on state ')
629+
630+
-- 3D
613631
local ini = math.random(3,5)
614632
local inj = math.random(3,5)
615633
local ink = math.random(3,5)
@@ -625,6 +643,15 @@ function nntest.Min()
625643
end
626644

627645
function nntest.Mean()
646+
-- 1D
647+
local ini = math.random(3,7)
648+
local input = torch.Tensor(ini):zero()
649+
local module = nn.Mean(1)
650+
651+
local err = jac.testJacobian(module,input)
652+
mytester:assertlt(err,precision, 'error on state ')
653+
654+
-- 3D
628655
local ini = math.random(3,5)
629656
local inj = math.random(3,5)
630657
local ink = math.random(3,5)
@@ -1423,6 +1450,15 @@ function nntest.SpatialLPPooling()
14231450
end
14241451

14251452
function nntest.Sum()
1453+
-- 1D
1454+
local ini = math.random(3,7)
1455+
local input = torch.Tensor(ini):zero()
1456+
local module = nn.Sum(1)
1457+
1458+
local err = jac.testJacobian(module,input)
1459+
mytester:assertlt(err,precision, 'error on state ')
1460+
1461+
-- 3D
14261462
local ini = math.random(3,5)
14271463
local inj = math.random(3,5)
14281464
local ink = math.random(3,5)

0 commit comments

Comments (0)
Please sign in to comment.