2
2
#define TH_GENERIC_FILE "generic/SparseLinear.c"
3
3
#else
4
4
5
+ static int nn_ (checkInput )(THTensor * t ) {
6
+ return t -> nDimension == 2 && t -> size [1 ] == 2 ;
7
+ }
8
+
9
+ static int nn_ (checkSize2D )(THTensor * t , long size0 , long size1 ) {
10
+ return t -> nDimension == 2 && t -> size [0 ] == size0 && t -> size [1 ] == size1 ;
11
+ }
12
+
13
+ static int nn_ (checkSize1D )(THTensor * t , long size0 ) {
14
+ return t -> nDimension == 1 && t -> size [0 ] == size0 ;
15
+ }
16
+
5
17
/*
 * Forward pass: output = weight * input + bias, where input is a sparse
 * vector given as an nnz x 2 tensor of (1-based index, value) pairs.
 *
 * If the module carries a "shardBuffer" (outDim x num_shards), the
 * accumulation is sharded across OpenMP threads, one buffer column per
 * thread, and the columns are summed at the end. Otherwise a serial
 * axpy loop accumulates directly into output.
 *
 * Lua stack: arg 1 is the module table, arg 2 the input tensor.
 * Pushes the module's output tensor and returns 1.
 *
 * FIX: the original raised luaL_error (a longjmp) from inside the
 * OpenMP parallel region; jumping out of a parallel region is undefined
 * behavior. Indices are now validated serially before the parallel loop.
 */
static int nn_(SparseLinear_updateOutput)(lua_State *L)
{
  long i;
  THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
  THTensor *weight = luaT_getfieldcheckudata(L, 1, "weight", torch_Tensor);
  THTensor *bias = luaT_getfieldcheckudata(L, 1, "bias", torch_Tensor);
  THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);

  long outDim = weight->size[0];
  long inDim = weight->size[1];

  luaL_argcheck(L, nn_(checkInput)(input), 2, "input size must be nnz x 2");
  luaL_argcheck(L, nn_(checkSize1D)(output, outDim), 1, "output size wrong");
  luaL_argcheck(L, nn_(checkSize1D)(bias, outDim), 1, "bias size wrong");

  lua_getfield(L, 1, "shardBuffer");
  if (!lua_isnil(L, -1)) {
    THTensor *buffer =
        luaT_getfieldcheckudata(L, 1, "shardBuffer", torch_Tensor);
    long num_shards = buffer->size[1];
    luaL_argcheck(L,
                  buffer->nDimension == 2 && buffer->size[0] == outDim &&
                      num_shards > 0,
                  1,
                  "shardBuffer size wrong");

    /* Validate all indices serially first: luaL_error performs a longjmp,
       which must never escape an OpenMP parallel region (UB). */
    for (i = 0; i < input->size[0]; i++) {
      long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
      if (offset < 0 || offset >= inDim) {
        luaL_error(L, "index out of bound. updateOutput: "
                      "%ld not between 1 and %ld", offset + 1, inDim);
      }
    }

    THTensor_(zero)(buffer);
    /* Each thread accumulates into its own buffer column (shardId). */
#pragma omp parallel for private(i) schedule(static) num_threads(num_shards)
    for (i = 0; i < input->size[0]; i++) {
      int shardId = omp_get_thread_num();
      long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
      THBlas_(axpy)(outDim,
                    THTensor_(get2d)(input, i, 1),
                    THTensor_(data)(weight) + offset * weight->stride[1],
                    weight->stride[0],
                    THTensor_(data)(buffer) + shardId * buffer->stride[1],
                    buffer->stride[0]);
    }

    /* Reduce shard columns into output, then add the bias. */
    THTensor_(sum)(output, buffer, 1);
    THTensor_(cadd)(output, bias, 1.0, output);

    lua_getfield(L, 1, "output");
    return 1;
  }

  /* Serial fallback: start from the bias and accumulate one weight
     column per non-zero input entry. */
  THTensor_(copy)(output, bias);
  for (i = 0; i < input->size[0]; i++) {
    long offset = (long)(THTensor_(get2d)(input, i, 0)) - 1;
    if (offset >= 0 && offset < inDim) { /* make sure indices are in bounds */
      real val = THTensor_(get2d)(input, i, 1);
      THBlas_(axpy)(output->size[0],
                    val,
                    THTensor_(data)(weight) + offset * weight->stride[1],
                    weight->stride[0],
                    THTensor_(data)(output),
                    output->stride[0]);
    } else {
      luaL_error(L, "index out of bound. updateOutput: "
                    "%ld not between 1 and %ld", offset + 1, inDim);
    }
  }

  lua_getfield(L, 1, "output");
  return 1;
}
35
92
@@ -42,39 +99,47 @@ static int nn_(SparseLinear_accGradParameters)(lua_State *L)
42
99
THTensor * weight = luaT_getfieldcheckudata (L , 1 , "weight" , torch_Tensor );
43
100
THTensor * gradBias = luaT_getfieldcheckudata (L , 1 , "gradBias" , torch_Tensor );
44
101
THTensor * gradWeight = luaT_getfieldcheckudata (L , 1 , "gradWeight" , torch_Tensor );
45
- THTensor * lastInput = luaT_getfieldcheckudata (L , 1 , "lastInput" , torch_Tensor );
46
102
real weightDecay = luaT_getfieldchecknumber (L , 1 , "weightDecay" );
47
- long dim = gradWeight -> size [1 ]; /* number of weights.. */
48
103
49
- for (i = 0 ; i < input -> size [0 ]; i ++ )
104
+ long nnz = input -> size [0 ];
105
+ long outDim = weight -> size [0 ];
106
+ long inDim = weight -> size [1 ];
107
+
108
+ luaL_argcheck (L , nn_ (checkInput )(input ), 2 , "input size must be nnz x 2" );
109
+ luaL_argcheck (
110
+ L , nn_ (checkSize1D )(gradOutput , outDim ), 3 , "gradOutput size wrong" );
111
+ luaL_argcheck (
112
+ L , nn_ (checkSize2D )(gradWeight , outDim , inDim ), 1 , "gradWeight size wrong" );
113
+ luaL_argcheck (
114
+ L , nn_ (checkSize1D )(gradBias , outDim ), 1 , "gradBias size wrong" );
115
+
116
+ #pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 100000)
117
+ for (i = 0 ; i < nnz ; i ++ )
50
118
{
51
119
long offset = (long )(THTensor_ (get2d )(input , i , 0 )) - 1 ;
52
120
53
- if (offset >= 0 && offset < dim ) /* make sure indices are in bounds.. */
121
+ if (offset >= 0 && offset < inDim ) /* make sure indices are in bounds.. */
54
122
{
55
123
real val = scale * THTensor_ (get2d )(input , i , 1 );
56
-
57
- THBlas_ (axpy )(gradOutput -> size [ 0 ],
58
- val ,
59
- THTensor_ (data )(gradOutput ),
60
- gradOutput -> stride [0 ],
61
- THTensor_ (data )(gradWeight )+ offset * gradWeight -> stride [1 ],
124
+
125
+ THBlas_ (axpy )(outDim ,
126
+ val ,
127
+ THTensor_ (data )(gradOutput ),
128
+ gradOutput -> stride [0 ],
129
+ THTensor_ (data )(gradWeight )+ offset * gradWeight -> stride [1 ],
62
130
gradWeight -> stride [0 ]);
63
131
}
64
132
else {
65
- printf ( "\naccGradParameters: %ld not between 1 and %ld\n" , offset + 1 , dim );
66
- luaL_error ( L , "index out of bound" );
133
+ luaL_error ( L , "index out of bound. accGradParameters: \
134
+ %ld not between 1 and %ld" , offset + 1 , inDim );
67
135
}
68
136
}
69
-
70
- THTensor_ (cadd )(gradBias , gradBias , scale , gradOutput );
71
-
137
+
138
+ THTensor_ (cadd )(gradBias , gradBias , scale , gradOutput );
139
+
72
140
if (weightDecay != 0 )
73
141
THTensor_ (cadd )(gradWeight , gradWeight , weightDecay , weight );
74
-
75
- THTensor_ (resizeAs )(lastInput , input );
76
- THTensor_ (copy )(lastInput , input );
77
-
142
+
78
143
return 0 ;
79
144
}
80
145
@@ -85,37 +150,137 @@ int nn_(SparseLinear_updateParameters)(lua_State *L)
85
150
THTensor * weight = luaT_getfieldcheckudata (L , 1 , "weight" , torch_Tensor );
86
151
THTensor * bias = luaT_getfieldcheckudata (L , 1 , "bias" , torch_Tensor );
87
152
THTensor * gradBias = luaT_getfieldcheckudata (L , 1 , "gradBias" , torch_Tensor );
88
- THTensor * gradWeight = luaT_getfieldcheckudata (L , 1 , "gradWeight" , torch_Tensor );
89
- THTensor * lastInput = luaT_getfieldcheckudata (L , 1 , "lastInput" , torch_Tensor );
90
-
91
- long dim = weight -> size [1 ]; /* number of weights.. */
153
+ THTensor * gradWeight = luaT_getfieldcheckudata (
154
+ L , 1 , "gradWeight" , torch_Tensor );
155
+ THTensor * lastInput = luaT_getfieldcheckudata (
156
+ L , 1 , "lastInput" , torch_Tensor );
157
+
158
+ long nnz = lastInput -> size [0 ];
159
+ long outDim = weight -> size [0 ];
160
+ long inDim = weight -> size [1 ];
161
+
162
+ luaL_argcheck (
163
+ L , nn_ (checkSize2D )(gradWeight , outDim , inDim ), 1 , "gradWeight size wrong" );
164
+ luaL_argcheck (
165
+ L , nn_ (checkSize1D )(bias , outDim ), 1 , "bias size wrong" );
166
+ luaL_argcheck (
167
+ L , nn_ (checkSize1D )(gradBias , outDim ), 1 , "gradBias size wrong" );
168
+
92
169
THTensor_ (cadd )(bias , bias , - learningRate , gradBias );
93
-
94
- for (i = 0 ; i < lastInput -> size [0 ]; i ++ )
170
+
171
+ #pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 50000)
172
+ for (i = 0 ; i < nnz ; i ++ )
95
173
{
96
174
long offset = (long )(THTensor_ (get2d )(lastInput , i , 0 )) - 1 ;
97
-
98
- if (offset >= 0 && offset < dim ) /* make sure indices are in bounds.. */
175
+
176
+ if (offset >= 0 && offset < inDim ) /* make sure indices are in bounds.. */
99
177
{
100
- THBlas_ (axpy )(bias -> size [0 ],
101
- - learningRate ,
102
- THTensor_ (data )(gradWeight )+ offset * gradWeight -> stride [1 ],
103
- gradWeight -> stride [0 ],
104
- THTensor_ (data )(weight )+ offset * weight -> stride [1 ],
178
+ real * pGradWeight =
179
+ THTensor_ (data )(gradWeight )+ offset * gradWeight -> stride [1 ];
180
+ THBlas_ (axpy )(outDim ,
181
+ - learningRate ,
182
+ pGradWeight ,
183
+ gradWeight -> stride [0 ],
184
+ THTensor_ (data )(weight )+ offset * weight -> stride [1 ],
105
185
weight -> stride [0 ]);
106
186
}
107
187
else {
108
- printf ("\nupdateParameters: %ld not between 1 and %ld\n" , offset + 1 , dim );
109
- luaL_error (L , "index out of bound" );
188
+ luaL_error (L , "index out of bound. updateParameters: \
189
+ %ld not between 1 and %ld" , offset + 1 , inDim );
190
+ }
191
+ }
192
+ return 0 ;
193
+ }
194
+
195
+ int nn_ (SparseLinear_zeroGradParameters )(lua_State * L )
196
+ {
197
+ long i ;
198
+ THTensor * gradBias = luaT_getfieldcheckudata (L , 1 , "gradBias" , torch_Tensor );
199
+ THTensor * gradWeight = luaT_getfieldcheckudata (
200
+ L , 1 , "gradWeight" , torch_Tensor );
201
+ THTensor * lastInput = luaT_getfieldcheckudata (
202
+ L , 1 , "lastInput" , torch_Tensor );
203
+
204
+ long nnz = lastInput -> size [0 ];
205
+ long outDim = gradWeight -> size [0 ];
206
+ long inDim = gradWeight -> size [1 ];
207
+
208
+ luaL_argcheck (
209
+ L , nn_ (checkSize1D )(gradBias , outDim ), 1 , "gradBias size wrong" );
210
+
211
+ THTensor_ (zero )(gradBias );
212
+ #pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 50000)
213
+ for (i = 0 ; i < nnz ; i ++ )
214
+ {
215
+ long offset = (long )(THTensor_ (get2d )(lastInput , i , 0 )) - 1 ;
216
+
217
+ if (offset >= 0 && offset < inDim ) /* make sure indices are in bounds.. */
218
+ {
219
+ real * pGradWeight =
220
+ THTensor_ (data )(gradWeight )+ offset * gradWeight -> stride [1 ];
221
+ if (gradWeight -> stride [0 ] == 1 ) {
222
+ THVector_ (fill )(pGradWeight , 0 , outDim );
223
+ } else {
224
+ long j ;
225
+ for (j = 0 ; j < outDim ; ++ j ) {
226
+ pGradWeight [j * gradWeight -> stride [0 ]] = 0 ;
227
+ }
228
+ }
229
+ }
230
+ else {
231
+ luaL_error (L , "index out of bound. zeroGradParameters: \
232
+ %ld not between 1 and %ld" , offset + 1 , inDim );
110
233
}
111
234
}
112
235
return 0 ;
113
236
}
114
237
238
+ static int nn_ (SparseLinear_updateGradInput )(lua_State * L ) {
239
+ THTensor * weight = luaT_getfieldcheckudata (L , 1 , "weight" , torch_Tensor );
240
+ THTensor * gradInput =
241
+ luaT_getfieldcheckudata (L , 1 , "gradInput" , torch_Tensor );
242
+ THTensor * input = luaT_checkudata (L , 2 , torch_Tensor );
243
+ THTensor * gradOutput = luaT_checkudata (L , 3 , torch_Tensor );
244
+
245
+ long i ;
246
+ long nnz = input -> size [0 ];
247
+ long outDim = weight -> size [0 ];
248
+ long inDim = weight -> size [1 ];
249
+
250
+ luaL_argcheck (
251
+ L , nn_ (checkInput )(input ), 2 , "input must be an nnz x 2 tensor" );
252
+ luaL_argcheck (
253
+ L , nn_ (checkSize1D )(gradOutput , outDim ), 3 , "gradOutput size wrong" );
254
+
255
+ THTensor_ (resize2d )(gradInput , input -> size [0 ], input -> size [1 ]);
256
+
257
+ #pragma omp parallel for private(i) schedule(static) if(outDim * nnz > 100000)
258
+ for (i = 0 ; i < nnz ; ++ i ) {
259
+ long offset = (long )(THTensor_ (get2d )(input , i , 0 )) - 1 ;
260
+ THTensor_ (set2d )(gradInput , i , 0 , offset + 1 );
261
+
262
+ if (offset >= 0 && offset < inDim ) {
263
+ real val =
264
+ THBlas_ (dot )(outDim ,
265
+ THTensor_ (data )(gradOutput ),
266
+ gradOutput -> stride [0 ],
267
+ THTensor_ (data )(weight ) + offset * weight -> stride [1 ],
268
+ weight -> stride [0 ]);
269
+ THTensor_ (set2d )(gradInput , i , 1 , val );
270
+ } else {
271
+ luaL_error (L , "index out of bound. updateGradInput: \
272
+ %ld not between 1 and %ld" , offset + 1 , inDim );
273
+ }
274
+ }
275
+ return 0 ;
276
+ }
277
+
115
278
static const struct luaL_Reg nn_ (SparseLinear__ ) [] = {
116
279
{"SparseLinear_updateOutput" , nn_ (SparseLinear_updateOutput )},
117
280
{"SparseLinear_accGradParameters" , nn_ (SparseLinear_accGradParameters )},
118
281
{"SparseLinear_updateParameters" , nn_ (SparseLinear_updateParameters )},
282
+ {"SparseLinear_zeroGradParameters" , nn_ (SparseLinear_zeroGradParameters )},
283
+ {"SparseLinear_updateGradInput" , nn_ (SparseLinear_updateGradInput )},
119
284
{NULL , NULL }
120
285
};
121
286
0 commit comments