Skip to content

Commit c9430d3

Browse files
gemm_convolution: memory access fix
1 parent 936ce25 commit c9430d3

File tree

1 file changed

+86
-9
lines changed

1 file changed

+86
-9
lines changed

src/cpu/x64/jit_gemm_convolution_utils.cpp

Lines changed: 86 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator {
8484

8585
private:
8686
void generate() override;
87+
// Emits JIT code that copies `size` elements (count held in a register) of
// `elemSize` bytes each from the address in `src` to the address in `dst`,
// one element per loop iteration. The definition only emits a copy loop for
// elemSize == 1 and elemSize == 4.
void copy_elems(const Xbyak::Reg64 &dst, const Xbyak::Reg64 &src, const Xbyak::Reg64 &size, const int elemSize);
88+
// Emits a JIT loop: while `idx` < `end` (signed compare), run the code that
// `fn(idx)` emits, then add `step` to `idx`. `fn` is invoked once, at code
// generation time, to produce the loop body.
void foreach (const Xbyak::Reg64 &idx, size_t step, const Xbyak::Reg64 &end, std::function<void(const Xbyak::Reg64 &)> && fn);
8789

8890
struct ker_args_t {
8991
float *dst;
@@ -142,6 +144,47 @@ struct jit_pp_kernel_t : pp_kernel_t, public jit_generator {
142144
Vmm vreg_bias(int idx) { return Vmm(idx_vreg_bias(idx)); };
143145
};
144146

147+
template <cpu_isa_t isa>
148+
void jit_pp_kernel_t<isa>::foreach (const Xbyak::Reg64 &idx, size_t step,
149+
const Xbyak::Reg64 &end, std::function<void(const Xbyak::Reg64&)> && fn)
150+
{
151+
Xbyak::Label loop, exit;
152+
153+
L(loop);
154+
cmp(idx, end);
155+
jge(exit);
156+
157+
fn(idx);
158+
159+
add(idx, step);
160+
jmp(loop);
161+
L(exit);
162+
}
163+
164+
template <cpu_isa_t isa>
165+
void jit_pp_kernel_t<isa>::copy_elems(const Xbyak::Reg64 &dst,
166+
const Xbyak::Reg64& src, const Xbyak::Reg64& size, const int elemSize) {
167+
push(rsi);
168+
push(r13);
169+
170+
xor_(rsi, rsi);
171+
172+
if (elemSize == 1) {
173+
foreach(rsi, 1, size, [&, this](const Xbyak::Reg64& idx) {
174+
mov(r13b, byte[src + idx * elemSize]);
175+
mov(byte[dst + idx * elemSize], r13b);
176+
});
177+
} else if (elemSize == 4) {
178+
foreach(rsi, 1, size, [&, this](const Xbyak::Reg64& idx) {
179+
mov(r13d, dword[src + idx * elemSize]);
180+
mov(dword[dst + idx * elemSize], r13d);
181+
});
182+
}
183+
184+
pop(r13);
185+
pop(rsi);
186+
}
187+
145188
template <cpu_isa_t isa>
146189
void jit_pp_kernel_t<isa>::generate() {
147190
using namespace Xbyak;
@@ -161,7 +204,18 @@ void jit_pp_kernel_t<isa>::generate() {
161204
mov(reg_table, l_table);
162205
}
163206

164-
auto apply_post_ops = [&]() {
207+
auto store_to_stack = [&](const Reg64 &from, const Reg64 &size) {
208+
sub(rsp, vlen * sizeof(float));
209+
mov(r8, rsp);
210+
copy_elems(r8, from, size, sizeof(float));
211+
};
212+
213+
auto load_from_stack = [&](const Vmm &to) {
214+
uni_vmovups(to, ptr[rsp]);
215+
add(rsp, vlen * sizeof(float));
216+
};
217+
218+
auto apply_post_ops = [&](bool apply_mask) {
165219
int eltwise_inj_idx = 0;
166220
int depthwise_inj_idx = 0;
167221
auto vreg_dst_ = vreg_dst(0);
@@ -176,8 +230,20 @@ void jit_pp_kernel_t<isa>::generate() {
176230
mov(reg_d_bias, reinterpret_cast<size_t>(post_op.depthwise.biases_data));
177231
lea(reg_d_weights, ptr[reg_d_weights + reg_oc_offset * sizeof(float)]);
178232
lea(reg_d_bias, ptr[reg_d_bias + reg_oc_offset * sizeof(float)]);
233+
if (apply_mask) {
234+
store_to_stack(reg_d_weights, reg_tmp);
235+
mov(reg_d_weights, rsp);
236+
237+
if (post_op.depthwise.alg == dnnl_depthwise_scale_shift) {
238+
store_to_stack(reg_d_bias, reg_tmp);
239+
mov(reg_d_bias, rsp);
240+
}
241+
}
179242
jit_depthwise_injectors_[depthwise_inj_idx]->compute_vector_range(vreg_dst_.getIdx(), vreg_dst_.getIdx() + 1,
180243
reg_d_weights, reg_d_bias, true);
244+
if (apply_mask) {
245+
add(rsp, (post_op.depthwise.alg == dnnl_depthwise_scale_shift ? 2 : 1) * vlen * sizeof(float));
246+
}
181247
depthwise_inj_idx++;
182248
} else if (post_op.is_quantization()) {
183249
bool do_dequantization = post_op.quantization.alg == alg_kind::quantization_quantize_dequantize;
@@ -243,6 +309,10 @@ void jit_pp_kernel_t<isa>::generate() {
243309
// Load accumulated value, convert to float, apply bias (if any), scaling,
244310
// and eltwise (if any); then convert to destination type and store
245311
auto compute = [&](bool apply_mask) {
312+
if (apply_mask) {
313+
push(r8);
314+
}
315+
246316
auto dst_addr = ptr[reg_dst];
247317
auto vreg_dst_ = vreg_dst(0);
248318
if (isa == avx512_common) {
@@ -251,11 +321,8 @@ void jit_pp_kernel_t<isa>::generate() {
251321
uni_vmovups(vreg_dst_, dst_addr);
252322
} else {
253323
if (apply_mask) {
254-
if (isa != sse41) {
255-
uni_vblendvps(vreg_dst_, vreg_zero, dst_addr, vreg_mask);
256-
} else {
257-
uni_vmovups(vreg_dst_, dst_addr);
258-
}
324+
store_to_stack(reg_dst, reg_tmp);
325+
load_from_stack(vreg_dst_);
259326
} else {
260327
uni_vmovups(vreg_dst_, dst_addr);
261328
}
@@ -270,7 +337,7 @@ void jit_pp_kernel_t<isa>::generate() {
270337
uni_vaddps(vreg_dst_, vreg_dst_, vreg_bias_);
271338
}
272339

273-
apply_post_ops();
340+
apply_post_ops(apply_mask);
274341

275342
if (isa == avx512_common) {
276343
uni_vmovups(dst_addr, vreg_dst_);
@@ -279,13 +346,20 @@ void jit_pp_kernel_t<isa>::generate() {
279346
if (isa != sse41) {
280347
vmaskmovps(dst_addr, vreg_mask, vreg_dst_);
281348
} else {
282-
lea(reg_ptr_maskmovdqu_dst, dst_addr);
283-
maskmovdqu(vreg_dst_, vreg_mask);
349+
sub(rsp, vlen * sizeof(float));
350+
mov(r8, rsp);
351+
uni_vmovups(ptr[r8], vreg_dst_);
352+
copy_elems(reg_dst, r8, reg_tmp, sizeof(float));
353+
add(rsp, vlen * sizeof(float));
284354
}
285355
} else {
286356
uni_vmovups(dst_addr, vreg_dst_);
287357
}
288358
}
359+
360+
if (apply_mask) {
361+
pop(r8);
362+
}
289363
};
290364

291365
Label loop_end;
@@ -303,6 +377,9 @@ void jit_pp_kernel_t<isa>::generate() {
303377
cmp(reg_len, vlen);
304378
jge(loop, T_NEAR);
305379
}
380+
381+
cmp(reg_tmp, 0);
382+
je(loop_end, T_NEAR);
306383

307384
L(loop_tail);
308385
mov(reg_tmp, reg_len); // reg_tmp is rcx, and we need cl for the shift

0 commit comments

Comments
 (0)