
Commit a38407a

speedup and optimizations for SparseLinear
1 parent 2340b9c commit a38407a

2 files changed: +226, -60 lines

SparseLinear.lua (19 additions, 18 deletions)
@@ -4,11 +4,16 @@ function SparseLinear:__init(inputSize, outputSize)
    parent.__init(self)
 
    self.weightDecay = 0
-   self.weight = torch.Tensor(outputSize, inputSize)
-   self.bias = torch.Tensor(outputSize)
-   self.gradWeight = torch.Tensor(outputSize, inputSize)
-   self.gradBias = torch.Tensor(outputSize)
-   self.lastInput = torch.Tensor()
+   self.weight = torch.Tensor(outputSize, inputSize):zero()
+   self.bias = torch.Tensor(outputSize):zero()
+   self.gradWeight = torch.Tensor(outputSize, inputSize):zero()
+   self.gradBias = torch.Tensor(outputSize):zero()
+   self.lastInput = nil
+
+   if torch.getnumthreads() > 1 and outputSize >= 128 then
+      self.shardBuffer = torch.Tensor(outputSize, torch.getnumthreads())
+   end
+
    -- state
    self.gradInput:resize(inputSize)
    self.output:resize(outputSize)
@@ -20,7 +25,7 @@ function SparseLinear:reset(stdv)
    if stdv then
       stdv = stdv * math.sqrt(3)
    else
-      stdv = 1./math.sqrt(self.weight:size(1))
+      stdv = 1./math.sqrt(self.weight:size(2))
    end
    if nn.oldSeed then
       for i=1,self.weight:size(1) do
@@ -40,22 +45,18 @@ function SparseLinear:updateOutput(input)
 end
 
 function SparseLinear:accGradParameters(input, gradOutput, scale)
+   if not self.lastInput then
+      self.lastInput = input:clone()
+   else
+      self.lastInput:resizeAs(input):copy(input)
+   end
+
    return input.nn.SparseLinear_accGradParameters(self, input, gradOutput, scale)
 end
 
 function SparseLinear:updateGradInput(input, gradOutput)
    if self.gradInput then
-      self.gradInput:resize(input:size())
-      self.gradInput:copy(input)
-      local numNonzero = self.gradInput:size(1)
-      for e=1,numNonzero do
-         local g = 0
-         local i = self.gradInput[{e,1}]
-         for j=1,self.output:size(1) do
-            g = g + self.weight[{j,i}] * gradOutput[j]
-         end
-         self.gradInput[{e,2}] = g
-      end
+      input.nn.SparseLinear_updateGradInput(self, input, gradOutput)
       return self.gradInput
    end
-end
+end
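
For context, a minimal usage sketch of the module this file implements (sizes and values here are made up; assumes the torch and nn packages are installed). SparseLinear consumes a sparse vector encoded as an nnz x 2 tensor of {index, value} rows with 1-based indices, and backward now routes through the C updateGradInput added in this commit:

require 'nn'

local inputSize, outputSize = 10000, 256
local module = nn.SparseLinear(inputSize, outputSize)

-- a sparse input with three non-zero entries, one {index, value} row each
local input = torch.Tensor{{1, 0.5}, {42, -1.2}, {9999, 2.0}}

local output = module:forward(input)                  -- dense outputSize-dim vector
local gradOutput = torch.randn(outputSize)
local gradInput = module:backward(input, gradOutput)  -- nnz x 2, same indices

With outputSize >= 128 and more than one torch thread, this also exercises the new sharded forward path below.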

generic/SparseLinear.c (207 additions, 42 deletions)
@@ -2,34 +2,91 @@
 #define TH_GENERIC_FILE "generic/SparseLinear.c"
 #else
 
+static int nn_(checkInput)(THTensor* t) {
+  return t->nDimension == 2 && t->size[1] == 2;
+}
+
+static int nn_(checkSize2D)(THTensor* t, long size0, long size1) {
+  return t->nDimension == 2 && t->size[0] == size0 && t->size[1] == size1;
+}
+
+static int nn_(checkSize1D)(THTensor* t, long size0) {
+  return t->nDimension == 1 && t->size[0] == size0;
+}
+
 static int nn_(SparseLinear_updateOutput)(lua_State *L)
 {
   long i;
   THTensor * input = luaT_checkudata(L, 2, torch_Tensor);
   THTensor * weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
   THTensor * bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
   THTensor * output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
-  long dim = weight->size[1]; /* number of weights.. */
+
+  long outDim = weight->size[0];
+  long inDim = weight->size[1];
+
+  luaL_argcheck(L, nn_(checkInput)(input), 2, "input size must be nnz x 2");
+  luaL_argcheck(L, nn_(checkSize1D)(output, outDim), 1, "output size wrong");
+  luaL_argcheck(L, nn_(checkSize1D)(bias, outDim), 1, "bias size wrong");
+
+  lua_getfield(L, 1, "shardBuffer");
+  if (!lua_isnil(L, -1)) {
+    THTensor *buffer =
+        luaT_getfieldcheckudata(L, 1, "shardBuffer", torch_Tensor);
+    long num_shards = buffer->size[1];
+    luaL_argcheck(L,
+                  buffer->nDimension == 2 && buffer->size[0] == outDim &&
+                      num_shards > 0,
+                  1,
+                  "shardBuffer size wrong");
+
+    THTensor_(zero)(buffer);
+#pragma omp parallel for private(i) schedule(static) num_threads(num_shards)
+    for (i = 0; i < input->size[0]; i++) {
+      int shardId = omp_get_thread_num();
+      long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
+
+      if (offset >= 0 && offset < inDim) {
+        THBlas_(axpy)(outDim,
+                      THTensor_(get2d)(input, i, 1),
+                      THTensor_(data)(weight) + offset * weight->stride[1],
+                      weight->stride[0],
+                      THTensor_(data)(buffer) + shardId * buffer->stride[1],
+                      buffer->stride[0]);
+      } else {
+        luaL_error(L, "index out of bound. updateOutput: \
+%ld not between 1 and %ld", offset + 1, inDim);
+      }
+    }
+
+    THTensor_(sum)(output, buffer, 1);
+    THTensor_(cadd)(output, bias, 1.0, output);
+
+    lua_getfield(L, 1, "output");
+    return 1;
+  }
 
   THTensor_(copy)(output, bias);
   for(i = 0; i < input->size[0]; i++)
   {
     long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
-    if(offset >= 0 && offset < dim) /* make sure indices are in bounds.. */
+    if(offset >= 0 && offset < inDim) /* make sure indices are in bounds.. */
     {
       real val = THTensor_(get2d)(input, i, 1);
-      THBlas_(axpy)(output->size[0],
-                    val,
+      THBlas_(axpy)(output->size[0],
+                    val,
                     THTensor_(data)(weight)+offset*weight->stride[1],
-                    weight->stride[0],
-                    THTensor_(data)(output),
+                    weight->stride[0],
+                    THTensor_(data)(output),
                     output->stride[0]);
     }
     else {
-      printf("\nupdateOutput: %ld not between 1 and %ld\n", offset+1, dim);
-      luaL_error(L, "index out of bound");
+      luaL_error(L, "index out of bound. updateOutput: \
+%ld not between 1 and %ld", offset + 1, inDim);
     }
   }
+
+  lua_getfield(L, 1, "output");
   return 1;
 }
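
The shardBuffer branch above avoids concurrent writes to output: each OpenMP thread accumulates its share of the non-zeros into its own column of the outDim x num_shards buffer, and the columns are summed into output at the end. A rough Lua rendering of that reduction (the function name and the round-robin shard assignment are illustrative; in the C code the shard id is omp_get_thread_num() and the loop runs in parallel):

local function shardedForward(weight, bias, input, numShards)
   local outDim = weight:size(1)
   local buffer = torch.zeros(outDim, numShards)
   for e = 1, input:size(1) do
      local shard = ((e - 1) % numShards) + 1   -- stand-in for omp_get_thread_num()
      local idx, val = input[e][1], input[e][2]
      -- buffer[{{}, shard}] += val * weight[{{}, idx}]
      buffer:select(2, shard):add(val, weight:select(2, idx))
   end
   -- reduce the per-shard partials, then add the bias
   return torch.sum(buffer, 2):squeeze(2):add(bias)
end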

@@ -42,39 +99,47 @@ static int nn_(SparseLinear_accGradParameters)(lua_State *L)
   THTensor * weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
   THTensor * gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor);
   THTensor * gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor);
-  THTensor * lastInput = luaT_getfieldcheckudata(L, 1, "lastInput", torch_Tensor);
   real weightDecay = luaT_getfieldchecknumber(L, 1, "weightDecay");
-  long dim = gradWeight->size[1]; /* number of weights.. */
 
-  for(i = 0; i < input->size[0]; i++)
+  long nnz = input->size[0];
+  long outDim = weight->size[0];
+  long inDim = weight->size[1];
+
+  luaL_argcheck(L, nn_(checkInput)(input), 2, "input size must be nnz x 2");
+  luaL_argcheck(
+      L, nn_(checkSize1D)(gradOutput, outDim), 3, "gradOutput size wrong");
+  luaL_argcheck(
+      L, nn_(checkSize2D)(gradWeight, outDim, inDim), 1, "gradWeight size wrong");
+  luaL_argcheck(
+      L, nn_(checkSize1D)(gradBias, outDim), 1, "gradBias size wrong");
+
+#pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 100000)
+  for(i = 0; i < nnz; i++)
   {
     long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
 
-    if(offset >= 0 && offset < dim) /* make sure indices are in bounds.. */
+    if(offset >= 0 && offset < inDim) /* make sure indices are in bounds.. */
     {
       real val = scale*THTensor_(get2d)(input, i, 1);
-
-      THBlas_(axpy)(gradOutput->size[0],
-                    val,
-                    THTensor_(data)(gradOutput),
-                    gradOutput->stride[0],
-                    THTensor_(data)(gradWeight)+offset*gradWeight->stride[1],
+
+      THBlas_(axpy)(outDim,
+                    val,
+                    THTensor_(data)(gradOutput),
+                    gradOutput->stride[0],
+                    THTensor_(data)(gradWeight)+offset*gradWeight->stride[1],
                     gradWeight->stride[0]);
     }
     else {
-      printf("\naccGradParameters: %ld not between 1 and %ld\n", offset+1, dim);
-      luaL_error(L, "index out of bound");
+      luaL_error(L, "index out of bound. accGradParameters: \
+%ld not between 1 and %ld", offset + 1, inDim);
     }
   }
-
-  THTensor_(cadd)(gradBias, gradBias, scale, gradOutput);
-
+
+  THTensor_(cadd)(gradBias, gradBias, scale, gradOutput);
+
   if(weightDecay != 0)
     THTensor_(cadd)(gradWeight, gradWeight, weightDecay, weight);
-
-  THTensor_(resizeAs)(lastInput, input);
-  THTensor_(copy)(lastInput, input);
-
+
   return 0;
 }
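
Per non-zero entry, the loop above performs a scaled column update. Written in torch operations (a reference sketch of the math, not the C implementation, with a made-up helper name):

local function accGradSketch(gradWeight, gradBias, input, gradOutput, scale)
   for e = 1, input:size(1) do
      local idx, val = input[e][1], input[e][2]
      -- gradWeight[{{}, idx}] += scale * val * gradOutput
      gradWeight:select(2, idx):add(scale * val, gradOutput)
   end
   -- gradBias += scale * gradOutput, once per call
   gradBias:add(scale, gradOutput)
end

Note that the OpenMP loop updates one gradWeight column per entry, so it is race-free only as long as the non-zero indices within one input are distinct.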

@@ -85,37 +150,137 @@ int nn_(SparseLinear_updateParameters)(lua_State *L)
   THTensor * weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
   THTensor * bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
   THTensor * gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor);
-  THTensor * gradWeight = luaT_getfieldcheckudata(L, 1, "gradWeight", torch_Tensor);
-  THTensor * lastInput = luaT_getfieldcheckudata(L, 1, "lastInput", torch_Tensor);
-
-  long dim = weight->size[1]; /* number of weights.. */
+  THTensor * gradWeight = luaT_getfieldcheckudata(
+      L, 1, "gradWeight", torch_Tensor);
+  THTensor * lastInput = luaT_getfieldcheckudata(
+      L, 1, "lastInput", torch_Tensor);
+
+  long nnz = lastInput->size[0];
+  long outDim = weight->size[0];
+  long inDim = weight->size[1];
+
+  luaL_argcheck(
+      L, nn_(checkSize2D)(gradWeight, outDim, inDim), 1, "gradWeight size wrong");
+  luaL_argcheck(
+      L, nn_(checkSize1D)(bias, outDim), 1, "bias size wrong");
+  luaL_argcheck(
+      L, nn_(checkSize1D)(gradBias, outDim), 1, "gradBias size wrong");
+
   THTensor_(cadd)(bias, bias, -learningRate, gradBias);
-
-  for(i = 0; i < lastInput->size[0]; i++)
+
+#pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 50000)
+  for(i = 0; i < nnz; i++)
   {
     long offset = (long)(THTensor_(get2d)(lastInput, i, 0)) - 1;
-
-    if(offset >= 0 && offset < dim) /* make sure indices are in bounds.. */
+
+    if(offset >= 0 && offset < inDim) /* make sure indices are in bounds.. */
     {
-      THBlas_(axpy)(bias->size[0],
-                    -learningRate,
-                    THTensor_(data)(gradWeight)+offset*gradWeight->stride[1],
-                    gradWeight->stride[0],
-                    THTensor_(data)(weight)+offset*weight->stride[1],
+      real* pGradWeight =
+          THTensor_(data)(gradWeight)+offset*gradWeight->stride[1];
+      THBlas_(axpy)(outDim,
+                    -learningRate,
+                    pGradWeight,
+                    gradWeight->stride[0],
+                    THTensor_(data)(weight)+offset*weight->stride[1],
                     weight->stride[0]);
     }
     else {
-      printf("\nupdateParameters: %ld not between 1 and %ld\n", offset+1, dim);
-      luaL_error(L, "index out of bound");
+      luaL_error(L, "index out of bound. updateParameters: \
+%ld not between 1 and %ld", offset + 1, inDim);
+    }
+  }
+  return 0;
+}
+
+int nn_(SparseLinear_zeroGradParameters)(lua_State *L)
+{
+  long i;
+  THTensor * gradBias = luaT_getfieldcheckudata(L, 1, "gradBias", torch_Tensor);
+  THTensor * gradWeight = luaT_getfieldcheckudata(
+      L, 1, "gradWeight", torch_Tensor);
+  THTensor * lastInput = luaT_getfieldcheckudata(
+      L, 1, "lastInput", torch_Tensor);
+
+  long nnz = lastInput->size[0];
+  long outDim = gradWeight->size[0];
+  long inDim = gradWeight->size[1];
+
+  luaL_argcheck(
+      L, nn_(checkSize1D)(gradBias, outDim), 1, "gradBias size wrong");
+
+  THTensor_(zero)(gradBias);
+#pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 50000)
+  for(i = 0; i < nnz; i++)
+  {
+    long offset = (long)(THTensor_(get2d)(lastInput, i, 0)) - 1;
+
+    if(offset >= 0 && offset < inDim) /* make sure indices are in bounds.. */
+    {
+      real* pGradWeight =
+          THTensor_(data)(gradWeight)+offset*gradWeight->stride[1];
+      if(gradWeight->stride[0] == 1) {
+        THVector_(fill)(pGradWeight, 0, outDim);
+      } else {
+        long j;
+        for(j = 0; j < outDim; ++j) {
+          pGradWeight[j * gradWeight->stride[0]] = 0;
+        }
+      }
+    }
+    else {
+      luaL_error(L, "index out of bound. zeroGradParameters: \
+%ld not between 1 and %ld", offset + 1, inDim);
     }
   }
   return 0;
 }
 
+static int nn_(SparseLinear_updateGradInput)(lua_State *L) {
+  THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
+  THTensor *gradInput =
+      luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
+  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+  THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
+
+  long i;
+  long nnz = input->size[0];
+  long outDim = weight->size[0];
+  long inDim = weight->size[1];
+
+  luaL_argcheck(
+      L, nn_(checkInput)(input), 2, "input must be an nnz x 2 tensor");
+  luaL_argcheck(
+      L, nn_(checkSize1D)(gradOutput, outDim), 3, "gradOutput size wrong");
+
+  THTensor_(resize2d)(gradInput, input->size[0], input->size[1]);
+
+#pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 100000)
+  for (i = 0; i < nnz; ++i) {
+    long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
+    THTensor_(set2d)(gradInput, i, 0, offset + 1);
+
+    if (offset >= 0 && offset < inDim) {
+      real val =
+          THBlas_(dot)(outDim,
+                       THTensor_(data)(gradOutput),
+                       gradOutput->stride[0],
+                       THTensor_(data)(weight) + offset * weight->stride[1],
+                       weight->stride[0]);
+      THTensor_(set2d)(gradInput, i, 1, val);
+    } else {
+      luaL_error(L, "index out of bound. updateGradInput: \
+%ld not between 1 and %ld", offset + 1, inDim);
+    }
+  }
+  return 0;
+}
+
 static const struct luaL_Reg nn_(SparseLinear__) [] = {
   {"SparseLinear_updateOutput", nn_(SparseLinear_updateOutput)},
   {"SparseLinear_accGradParameters", nn_(SparseLinear_accGradParameters)},
   {"SparseLinear_updateParameters", nn_(SparseLinear_updateParameters)},
+  {"SparseLinear_zeroGradParameters", nn_(SparseLinear_zeroGradParameters)},
+  {"SparseLinear_updateGradInput", nn_(SparseLinear_updateGradInput)},
   {NULL, NULL}
 };
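
Finally, the new updateGradInput replaces the old per-element Lua loop with one BLAS dot per non-zero: the gradient with respect to each stored value is the dot product of the matching weight column with gradOutput. The same math in torch operations (sketch only; the helper name is made up):

local function gradInputSketch(weight, input, gradOutput)
   local gradInput = input:clone()          -- keep the nnz x 2 {index, value} layout
   for e = 1, input:size(1) do
      local idx = input[e][1]
      -- d(loss)/d(value_e) = weight[{{}, idx}] . gradOutput
      gradInput[e][2] = weight:select(2, idx):dot(gradOutput)
   end
   return gradInput
end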
