@@ -84,6 +84,8 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator {
 private:
     void generate() override;
+    void copy_elems(const Xbyak::Reg64 &dst, const Xbyak::Reg64 &src, const Xbyak::Reg64 &size, const int elemSize);
+    void foreach(const Xbyak::Reg64 &idx, size_t step, const Xbyak::Reg64 &end, std::function<void(const Xbyak::Reg64 &)> &&fn);
 
     struct ker_args_t {
         float *dst;
@@ -142,6 +144,47 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator {
     Vmm vreg_bias(int idx) { return Vmm(idx_vreg_bias(idx)); };
 };
 
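+// Emits a JIT-time loop over the runtime range [idx, end): the generated
+// code invokes fn(idx) and advances idx by step until idx >= end (signed).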
+template <cpu_isa_t isa>
+void jit_pp_kernel_t<isa>::foreach(const Xbyak::Reg64 &idx, size_t step,
+        const Xbyak::Reg64 &end, std::function<void(const Xbyak::Reg64 &)> &&fn) {
+    Xbyak::Label loop, exit;
+
+    L(loop);
+    cmp(idx, end);
+    jge(exit);
+
+    fn(idx);
+
+    add(idx, step);
+    jmp(loop);
+    L(exit);
+}
+
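+// Emits an element-wise copy of `size` elements of 1 or 4 bytes from src to
+// dst, using rsi as the loop index and r13 as scratch (both saved/restored).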
+template <cpu_isa_t isa>
+void jit_pp_kernel_t<isa>::copy_elems(const Xbyak::Reg64 &dst,
+        const Xbyak::Reg64 &src, const Xbyak::Reg64 &size, const int elemSize) {
+    push(rsi);
+    push(r13);
+
+    xor_(rsi, rsi);
+
+    if (elemSize == 1) {
+        foreach(rsi, 1, size, [&, this](const Xbyak::Reg64 &idx) {
+            mov(r13b, byte[src + idx * elemSize]);
+            mov(byte[dst + idx * elemSize], r13b);
+        });
+    } else if (elemSize == 4) {
+        foreach(rsi, 1, size, [&, this](const Xbyak::Reg64 &idx) {
+            mov(r13d, dword[src + idx * elemSize]);
+            mov(dword[dst + idx * elemSize], r13d);
+        });
+    }
+
+    pop(r13);
+    pop(rsi);
+}
+
 template <cpu_isa_t isa>
 void jit_pp_kernel_t<isa>::generate() {
     using namespace Xbyak;
@@ -161,7 +204,18 @@ void jit_pp_kernel_t<isa>::generate() {
         mov(reg_table, l_table);
     }
 
-    auto apply_post_ops = [&]() {
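+    // Tail-path helpers: store_to_stack() reserves a vlen-float stack slot
+    // and copies `size` floats from the memory at `from` into it (clobbering
+    // r8); load_from_stack() reads the slot back into a vector register and
+    // releases it.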
+    auto store_to_stack = [&](const Reg64 &from, const Reg64 &size) {
+        sub(rsp, vlen * sizeof(float));
+        mov(r8, rsp);
+        copy_elems(r8, from, size, sizeof(float));
+    };
+
+    auto load_from_stack = [&](const Vmm &to) {
+        uni_vmovups(to, ptr[rsp]);
+        add(rsp, vlen * sizeof(float));
+    };
+
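+    // apply_mask selects the tail path, where per-channel depthwise data is
+    // staged through the stack so only valid elements are ever dereferenced.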
+    auto apply_post_ops = [&](bool apply_mask) {
         int eltwise_inj_idx = 0;
         int depthwise_inj_idx = 0;
         auto vreg_dst_ = vreg_dst(0);
@@ -176,8 +230,20 @@ void jit_pp_kernel_t<isa>::generate() {
                 mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
                 lea(reg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float)]);
                 lea(reg_d_bias, ptr[reg_d_bias + reg_oc_offset * sizeof(float)]);
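+                // On the tail, copy the reg_tmp remaining weights (and biases
+                // for scale_shift) into stack slots and repoint the data
+                // registers at them, so the injector's full-vector loads stay
+                // within bounds.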
+                if (apply_mask) {
+                    store_to_stack(reg_d_weights, reg_tmp);
+                    mov(reg_d_weights, rsp);
+
+                    if (post_op.depthwise.alg == dnnl_depthwise_scale_shift) {
+                        store_to_stack(reg_d_bias, reg_tmp);
+                        mov(reg_d_bias, rsp);
+                    }
+                }
                 jit_depthwise_injectors_[depthwise_inj_idx]->compute_vector_range(vreg_dst_.getIdx(), vreg_dst_.getIdx() + 1,
                         reg_d_weights, reg_d_bias, true);
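+                // Release the one or two slots reserved above; store_to_stack()
+                // leaves them allocated.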
+                if (apply_mask) {
+                    add(rsp, (post_op.depthwise.alg == dnnl_depthwise_scale_shift ? 2 : 1) * vlen * sizeof(float));
+                }
                 depthwise_inj_idx++;
             } else if (post_op.is_quantization()) {
                 bool do_dequantization = post_op.quantization.alg == alg_kind::quantization_quantize_dequantize;
@@ -243,6 +309,10 @@ void jit_pp_kernel_t<isa>::generate() {
     // Load accumulated value, convert to float, apply bias (if any), scaling,
     // and eltwise (if any); then convert to destination type and store
     auto compute = [&](bool apply_mask) {
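+        // r8 is used as scratch by store_to_stack() and the sse41 tail store
+        // below, so preserve the caller's value on the tail path.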
+        if (apply_mask) {
+            push(r8);
+        }
+
         auto dst_addr = ptr[reg_dst];
         auto vreg_dst_ = vreg_dst(0);
         if (isa == avx512_common) {
@@ -251,11 +321,8 @@ void jit_pp_kernel_t<isa>::generate() {
             uni_vmovups(vreg_dst_, dst_addr);
         } else {
             if (apply_mask) {
-                if (isa != sse41) {
-                    uni_vblendvps(vreg_dst_, vreg_zero, dst_addr, vreg_mask);
-                } else {
-                    uni_vmovups(vreg_dst_, dst_addr);
-                }
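+                // Instead of a masked load, copy the reg_tmp valid elements
+                // through a stack slot and load one full vector from it.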
+                store_to_stack(reg_dst, reg_tmp);
+                load_from_stack(vreg_dst_);
             } else {
                 uni_vmovups(vreg_dst_, dst_addr);
             }
@@ -270,7 +337,7 @@ void jit_pp_kernel_t<isa>::generate() {
             uni_vaddps(vreg_dst_, vreg_dst_, vreg_bias_);
         }
 
-        apply_post_ops();
+        apply_post_ops(apply_mask);
 
         if (isa == avx512_common) {
             uni_vmovups(dst_addr, vreg_dst_);
@@ -279,13 +346,20 @@ void jit_pp_kernel_t<isa>::generate() {
             if (isa != sse41) {
                 vmaskmovps(dst_addr, vreg_mask, vreg_dst_);
             } else {
-                lea(reg_ptr_maskmovdqu_dst, dst_addr);
-                maskmovdqu(vreg_dst_, vreg_mask);
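+                // Masked store without maskmovdqu: write the whole vector to
+                // a stack slot, then copy back only the reg_tmp valid
+                // elements.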
+                sub(rsp, vlen * sizeof(float));
+                mov(r8, rsp);
+                uni_vmovups(ptr[r8], vreg_dst_);
+                copy_elems(reg_dst, r8, reg_tmp, sizeof(float));
+                add(rsp, vlen * sizeof(float));
             }
         } else {
             uni_vmovups(dst_addr, vreg_dst_);
         }
     }
+
+        if (apply_mask) {
+            pop(r8);
+        }
     };
 
     Label loop_end;
@@ -303,6 +377,9 @@ void jit_pp_kernel_t<isa>::generate() {
         cmp(reg_len, vlen);
         jge(loop, T_NEAR);
     }
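+    // Skip the tail block entirely when no elements remain.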
+
+    cmp(reg_tmp, 0);
+    je(loop_end, T_NEAR);
 
     L(loop_tail);
     mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift
0 commit comments