Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Commit c1f3bee

Browse files
ayzhuangsilee2
authored andcommitted
Migrate PR#3638 from master. (#3735)
* Check size requirement before creating scratchpad. * Check max scratchpad size before allocating scratchpad_buffer. * Add the same checks for CODEGEN. * Fix unused-parameter warning. * Fix a typo. * Address PR feedback. * Fix a bug. * Fix a typo.
1 parent be738d0 commit c1f3bee

33 files changed

+998
-555
lines changed

src/ngraph/runtime/cpu/builder/add.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ namespace ngraph
4040

4141
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
4242
auto sum_pd = mkldnn_emitter->get_elementwise_add_desc(node);
43-
QUERY_SCRATCHPAD(sum, sum_pd);
43+
size_t scratchpad_size = QUERY_SCRATCHPAD(sum, sum_pd);
4444

4545
// Add needs 4 primitives: input0, input1, result, and sum.
4646
size_t add_index = mkldnn_emitter->reserve_primitive_space(4);
@@ -55,6 +55,7 @@ namespace ngraph
5555
auto functor = [&,
5656
sum_pd,
5757
add_index,
58+
scratchpad_size,
5859
arg0_buffer_index,
5960
arg1_buffer_index,
6061
out_buffer_index](CPURuntimeContext* ctx,
@@ -76,7 +77,7 @@ namespace ngraph
7677
ctx, deps[2], ctx->buffer_data[out_buffer_index]);
7778

7879
cpu::mkldnn_utils::mkldnn_invoke_primitive(
79-
ctx, add_index, deps, cpu::mkldnn_utils::OpType::ADD);
80+
ctx, add_index, deps, cpu::mkldnn_utils::OpType::ADD, scratchpad_size);
8081
};
8182
functors.emplace_back(functor);
8283
}

src/ngraph/runtime/cpu/builder/avg_pool.cpp

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -55,32 +55,40 @@ namespace ngraph
5555
auto avg_pool_desc =
5656
mkldnn_emitter->get_avg_pooling_forward_desc<ngraph::op::AvgPool>(node,
5757
false);
58-
QUERY_SCRATCHPAD(pooling_forward, avg_pool_desc);
58+
size_t scratchpad_size = QUERY_SCRATCHPAD(pooling_forward, avg_pool_desc);
5959

6060
// AvgPool needs 3 primitives: input, result, and pooling_forward.
6161
size_t avg_pool_index = mkldnn_emitter->reserve_primitive_space(3);
6262
auto& deps = mkldnn_emitter->get_primitive_deps(avg_pool_index);
6363

64-
auto functor =
65-
[&, avg_pool_desc, avg_pool_index, arg0_buffer_index, out_buffer_index](
66-
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
67-
if (ctx->first_iteration)
68-
{
69-
mkldnn_emitter->build_pooling_forward(ctx->mkldnn_memories,
70-
ctx->mkldnn_primitives,
71-
ctx->mkldnn_scratchpad_mds,
72-
avg_pool_desc,
73-
deps,
74-
avg_pool_index);
75-
}
76-
cpu::mkldnn_utils::set_memory_ptr(
77-
ctx, deps[0], ctx->buffer_data[arg0_buffer_index]);
78-
cpu::mkldnn_utils::set_memory_ptr(
79-
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
80-
81-
cpu::mkldnn_utils::mkldnn_invoke_primitive(
82-
ctx, avg_pool_index, deps, cpu::mkldnn_utils::OpType::AVGPOOL);
83-
};
64+
auto functor = [&,
65+
avg_pool_desc,
66+
avg_pool_index,
67+
scratchpad_size,
68+
arg0_buffer_index,
69+
out_buffer_index](CPURuntimeContext* ctx,
70+
CPUExecutionContext* ectx) {
71+
if (ctx->first_iteration)
72+
{
73+
mkldnn_emitter->build_pooling_forward(ctx->mkldnn_memories,
74+
ctx->mkldnn_primitives,
75+
ctx->mkldnn_scratchpad_mds,
76+
avg_pool_desc,
77+
deps,
78+
avg_pool_index);
79+
}
80+
cpu::mkldnn_utils::set_memory_ptr(
81+
ctx, deps[0], ctx->buffer_data[arg0_buffer_index]);
82+
cpu::mkldnn_utils::set_memory_ptr(
83+
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
84+
85+
cpu::mkldnn_utils::mkldnn_invoke_primitive(
86+
ctx,
87+
avg_pool_index,
88+
deps,
89+
cpu::mkldnn_utils::OpType::AVGPOOL,
90+
scratchpad_size);
91+
};
8492
functors.emplace_back(functor);
8593
}
8694
else
@@ -145,7 +153,8 @@ namespace ngraph
145153
auto avg_pool_desc =
146154
mkldnn_emitter->get_avg_pooling_backward_desc<ngraph::op::AvgPoolBackprop>(
147155
node);
148-
QUERY_SCRATCHPAD_2ARGS(avg_pooling_backward, avg_pool_fwd_desc, avg_pool_desc);
156+
size_t scratchpad_size = QUERY_SCRATCHPAD_2ARGS(
157+
avg_pooling_backward, avg_pool_fwd_desc, avg_pool_desc);
149158

150159
// AvgPoolBackprop needs 3 primitives: input, result, and pooling_backward.
151160
size_t avg_pool_index = mkldnn_emitter->reserve_primitive_space(3);
@@ -155,6 +164,7 @@ namespace ngraph
155164
avg_pool_desc,
156165
avg_pool_fwd_desc,
157166
avg_pool_index,
167+
scratchpad_size,
158168
delta_buffer_index,
159169
out_buffer_index](CPURuntimeContext* ctx,
160170
CPUExecutionContext* ectx) {
@@ -174,7 +184,11 @@ namespace ngraph
174184
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
175185

176186
cpu::mkldnn_utils::mkldnn_invoke_primitive(
177-
ctx, avg_pool_index, deps, cpu::mkldnn_utils::OpType::AVGPOOLBACKPROP);
187+
ctx,
188+
avg_pool_index,
189+
deps,
190+
cpu::mkldnn_utils::OpType::AVGPOOLBACKPROP,
191+
scratchpad_size);
178192
};
179193
functors.emplace_back(functor);
180194
}

src/ngraph/runtime/cpu/builder/batch_norm.cpp

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,8 @@ namespace ngraph
8484
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
8585
auto batchnorm_desc =
8686
mkldnn_emitter->get_batchnorm_forward_desc<OP>(node, true);
87-
QUERY_SCRATCHPAD_2ARGS(batchnorm_forward, batchnorm_desc, ops);
87+
size_t scratchpad_size =
88+
QUERY_SCRATCHPAD_2ARGS(batchnorm_forward, batchnorm_desc, ops);
8889

8990
auto weights_shape = Shape{2, args[0].get_size()};
9091
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
@@ -101,6 +102,7 @@ namespace ngraph
101102
training,
102103
ops,
103104
batchnorm_index,
105+
scratchpad_size,
104106
stacked_weights,
105107
weight_sizes,
106108
arg0_buffer_index,
@@ -140,7 +142,11 @@ namespace ngraph
140142
ctx, deps[4], ctx->buffer_data[out2_buffer_index]);
141143

142144
cpu::mkldnn_utils::mkldnn_invoke_primitive(
143-
ctx, batchnorm_index, deps, cpu::mkldnn_utils::OpType::BATCHNORM3ARGS);
145+
ctx,
146+
batchnorm_index,
147+
deps,
148+
cpu::mkldnn_utils::OpType::BATCHNORM3ARGS,
149+
scratchpad_size);
144150
};
145151
functors.emplace_back(functor);
146152
}
@@ -155,7 +161,8 @@ namespace ngraph
155161
auto batchnorm_desc =
156162
mkldnn_emitter->get_batchnorm_forward_desc<OP>(node, false);
157163

158-
QUERY_SCRATCHPAD_2ARGS(batchnorm_forward, batchnorm_desc, ops);
164+
size_t scratchpad_size =
165+
QUERY_SCRATCHPAD_2ARGS(batchnorm_forward, batchnorm_desc, ops);
159166

160167
auto weights_shape = Shape{2, args[0].get_size()};
161168
auto weights_desc = mkldnn_emitter->build_memory_descriptor(
@@ -172,6 +179,7 @@ namespace ngraph
172179
training,
173180
ops,
174181
batchnorm_index,
182+
scratchpad_size,
175183
stacked_weights,
176184
weight_sizes,
177185
arg0_buffer_index,
@@ -211,7 +219,11 @@ namespace ngraph
211219
ctx, deps[4], ctx->buffer_data[out0_buffer_index]);
212220

213221
cpu::mkldnn_utils::mkldnn_invoke_primitive(
214-
ctx, batchnorm_index, deps, cpu::mkldnn_utils::OpType::BATCHNORM5ARGS);
222+
ctx,
223+
batchnorm_index,
224+
deps,
225+
cpu::mkldnn_utils::OpType::BATCHNORM5ARGS,
226+
scratchpad_size);
215227
};
216228
functors.emplace_back(functor);
217229
}
@@ -444,14 +456,16 @@ namespace ngraph
444456
static_cast<const ngraph::op::BatchNormTrainingBackprop*>(node);
445457
auto eps = batchnorm->get_eps_value();
446458
(void)eps; // Use depends on mkl-dnn version
447-
QUERY_SCRATCHPAD_3ARGS(batchnorm_backward, batchnorm_desc, input_desc, eps);
459+
size_t scratchpad_size =
460+
QUERY_SCRATCHPAD_3ARGS(batchnorm_backward, batchnorm_desc, input_desc, eps);
448461

449462
auto functor = [&,
450463
batchnorm_desc,
451464
input_desc,
452465
weights_desc,
453466
dweights_desc,
454467
batchnorm_index,
468+
scratchpad_size,
455469
stacked_weights,
456470
stacked_dweights,
457471
weight_sizes,
@@ -499,7 +513,11 @@ namespace ngraph
499513
cpu::mkldnn_utils::set_memory_ptr(ctx, deps[6], stacked_dweights.get());
500514

501515
cpu::mkldnn_utils::mkldnn_invoke_primitive(
502-
ctx, batchnorm_index, deps, cpu::mkldnn_utils::OpType::BATCHNORMBACKPROP);
516+
ctx,
517+
batchnorm_index,
518+
deps,
519+
cpu::mkldnn_utils::OpType::BATCHNORMBACKPROP,
520+
scratchpad_size);
503521

504522
memcpy(ctx->buffer_data[out1_buffer_index],
505523
stacked_dweights.get(),

src/ngraph/runtime/cpu/builder/bounded_relu.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ namespace ngraph
4444
{
4545
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
4646
auto bounded_relu_desc = mkldnn_emitter->get_bounded_relu_desc(node);
47-
QUERY_SCRATCHPAD(eltwise_forward, bounded_relu_desc);
47+
size_t scratchpad_size = QUERY_SCRATCHPAD(eltwise_forward, bounded_relu_desc);
4848

4949
// BoundedRelu needs 3 primitives: input, result, and eltwise_forward.
5050
auto bounded_relu_index = mkldnn_emitter->reserve_primitive_space(3);
@@ -53,6 +53,7 @@ namespace ngraph
5353
auto functor = [&,
5454
bounded_relu_desc,
5555
bounded_relu_index,
56+
scratchpad_size,
5657
input_buffer_index,
5758
out_buffer_index](CPURuntimeContext* ctx,
5859
CPUExecutionContext* ectx) {
@@ -71,7 +72,11 @@ namespace ngraph
7172
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
7273

7374
cpu::mkldnn_utils::mkldnn_invoke_primitive(
74-
ctx, bounded_relu_index, deps, cpu::mkldnn_utils::OpType::BOUNDEDRELU);
75+
ctx,
76+
bounded_relu_index,
77+
deps,
78+
cpu::mkldnn_utils::OpType::BOUNDEDRELU,
79+
scratchpad_size);
7580
};
7681
functors.emplace_back(functor);
7782
}

src/ngraph/runtime/cpu/builder/concat.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ namespace ngraph
101101
auto& mkldnn_emitter = external_function->get_mkldnn_emitter();
102102
auto concat_pd =
103103
mkldnn_emitter->get_concat_desc<ngraph::op::Concat>(node, nargs);
104-
QUERY_SCRATCHPAD(concat, concat_pd);
104+
size_t scratchpad_size = QUERY_SCRATCHPAD(concat, concat_pd);
105105

106106
std::vector<mkldnn::memory::desc> inputs_data_desc;
107107
for (size_t i = 0; i < nargs; i++)
@@ -115,6 +115,7 @@ namespace ngraph
115115

116116
auto functor = [&,
117117
concat_pd,
118+
scratchpad_size,
118119
inputs_data_desc,
119120
arg_buffer_indices,
120121
nargs,
@@ -140,7 +141,11 @@ namespace ngraph
140141
ctx, deps[nargs], ctx->buffer_data[out_buffer_index]);
141142

142143
cpu::mkldnn_utils::mkldnn_invoke_primitive(
143-
ctx, concat_index, deps, cpu::mkldnn_utils::OpType::CONCAT);
144+
ctx,
145+
concat_index,
146+
deps,
147+
cpu::mkldnn_utils::OpType::CONCAT,
148+
scratchpad_size);
144149
};
145150

146151
functors.emplace_back(functor);

src/ngraph/runtime/cpu/builder/convert_layout.cpp

Lines changed: 32 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ namespace ngraph
4343
auto input_desc = mkldnn_utils::get_input_mkldnn_md(node, 0);
4444
auto result_desc = mkldnn_utils::get_output_mkldnn_md(node, 0);
4545

46+
size_t scratchpad_size = 0;
47+
4648
#if MKLDNN_VERSION_MAJOR < 1
4749
if (input_desc.data.format == mkldnn_nchw &&
4850
result_desc.data.format == mkldnn_goihw)
@@ -131,32 +133,41 @@ namespace ngraph
131133
mkldnn::memory::format_tag::goihw);
132134
}
133135

134-
mkldnn_emitter->query_scratchpad_reorder(input_desc, result_desc);
136+
scratchpad_size = mkldnn_emitter->query_scratchpad_reorder(input_desc, result_desc);
135137
#endif
136138
// ConvertLayout needs 3 primitives: input, result, and reorder.
137139
size_t reorder_index = mkldnn_emitter->reserve_primitive_space(3);
138140
auto& deps = mkldnn_emitter->get_primitive_deps(reorder_index);
139-
auto functor =
140-
[&, input_desc, result_desc, reorder_index, arg_buffer_index, out_buffer_index](
141-
CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
142-
if (ctx->first_iteration)
143-
{
144-
mkldnn_emitter->build_reorder(ctx->mkldnn_memories,
145-
ctx->mkldnn_primitives,
146-
ctx->mkldnn_scratchpad_mds,
147-
input_desc,
148-
result_desc,
149-
deps,
150-
reorder_index);
151-
}
152-
cpu::mkldnn_utils::set_memory_ptr(
153-
ctx, deps[0], ctx->buffer_data[arg_buffer_index]);
154-
cpu::mkldnn_utils::set_memory_ptr(
155-
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
141+
auto functor = [&,
142+
input_desc,
143+
result_desc,
144+
reorder_index,
145+
scratchpad_size,
146+
arg_buffer_index,
147+
out_buffer_index](CPURuntimeContext* ctx,
148+
CPUExecutionContext* ectx) {
149+
if (ctx->first_iteration)
150+
{
151+
mkldnn_emitter->build_reorder(ctx->mkldnn_memories,
152+
ctx->mkldnn_primitives,
153+
ctx->mkldnn_scratchpad_mds,
154+
input_desc,
155+
result_desc,
156+
deps,
157+
reorder_index);
158+
}
159+
cpu::mkldnn_utils::set_memory_ptr(
160+
ctx, deps[0], ctx->buffer_data[arg_buffer_index]);
161+
cpu::mkldnn_utils::set_memory_ptr(
162+
ctx, deps[1], ctx->buffer_data[out_buffer_index]);
156163

157-
cpu::mkldnn_utils::mkldnn_invoke_primitive(
158-
ctx, reorder_index, deps, cpu::mkldnn_utils::OpType::CONVERTLAYOUT);
159-
};
164+
cpu::mkldnn_utils::mkldnn_invoke_primitive(
165+
ctx,
166+
reorder_index,
167+
deps,
168+
cpu::mkldnn_utils::OpType::CONVERTLAYOUT,
169+
scratchpad_size);
170+
};
160171
functors.emplace_back(functor);
161172
}
162173

0 commit comments

Comments
 (0)