@@ -55,32 +55,40 @@ namespace ngraph
55
55
auto avg_pool_desc =
56
56
mkldnn_emitter->get_avg_pooling_forward_desc <ngraph::op::AvgPool>(node,
57
57
false );
58
- QUERY_SCRATCHPAD (pooling_forward, avg_pool_desc);
58
+ size_t scratchpad_size = QUERY_SCRATCHPAD (pooling_forward, avg_pool_desc);
59
59
60
60
// AvgPool needs 3 primitives: input, result, and pooling_forward.
61
61
size_t avg_pool_index = mkldnn_emitter->reserve_primitive_space (3 );
62
62
auto & deps = mkldnn_emitter->get_primitive_deps (avg_pool_index);
63
63
64
- auto functor =
65
- [&, avg_pool_desc, avg_pool_index, arg0_buffer_index, out_buffer_index](
66
- CPURuntimeContext* ctx, CPUExecutionContext* ectx) {
67
- if (ctx->first_iteration )
68
- {
69
- mkldnn_emitter->build_pooling_forward (ctx->mkldnn_memories ,
70
- ctx->mkldnn_primitives ,
71
- ctx->mkldnn_scratchpad_mds ,
72
- avg_pool_desc,
73
- deps,
74
- avg_pool_index);
75
- }
76
- cpu::mkldnn_utils::set_memory_ptr (
77
- ctx, deps[0 ], ctx->buffer_data [arg0_buffer_index]);
78
- cpu::mkldnn_utils::set_memory_ptr (
79
- ctx, deps[1 ], ctx->buffer_data [out_buffer_index]);
80
-
81
- cpu::mkldnn_utils::mkldnn_invoke_primitive (
82
- ctx, avg_pool_index, deps, cpu::mkldnn_utils::OpType::AVGPOOL);
83
- };
64
+ auto functor = [&,
65
+ avg_pool_desc,
66
+ avg_pool_index,
67
+ scratchpad_size,
68
+ arg0_buffer_index,
69
+ out_buffer_index](CPURuntimeContext* ctx,
70
+ CPUExecutionContext* ectx) {
71
+ if (ctx->first_iteration )
72
+ {
73
+ mkldnn_emitter->build_pooling_forward (ctx->mkldnn_memories ,
74
+ ctx->mkldnn_primitives ,
75
+ ctx->mkldnn_scratchpad_mds ,
76
+ avg_pool_desc,
77
+ deps,
78
+ avg_pool_index);
79
+ }
80
+ cpu::mkldnn_utils::set_memory_ptr (
81
+ ctx, deps[0 ], ctx->buffer_data [arg0_buffer_index]);
82
+ cpu::mkldnn_utils::set_memory_ptr (
83
+ ctx, deps[1 ], ctx->buffer_data [out_buffer_index]);
84
+
85
+ cpu::mkldnn_utils::mkldnn_invoke_primitive (
86
+ ctx,
87
+ avg_pool_index,
88
+ deps,
89
+ cpu::mkldnn_utils::OpType::AVGPOOL,
90
+ scratchpad_size);
91
+ };
84
92
functors.emplace_back (functor);
85
93
}
86
94
else
@@ -145,7 +153,8 @@ namespace ngraph
145
153
auto avg_pool_desc =
146
154
mkldnn_emitter->get_avg_pooling_backward_desc <ngraph::op::AvgPoolBackprop>(
147
155
node);
148
- QUERY_SCRATCHPAD_2ARGS (avg_pooling_backward, avg_pool_fwd_desc, avg_pool_desc);
156
+ size_t scratchpad_size = QUERY_SCRATCHPAD_2ARGS (
157
+ avg_pooling_backward, avg_pool_fwd_desc, avg_pool_desc);
149
158
150
159
// AvgPoolBackprop needs 3 primitives: input, result, and pooling_backward.
151
160
size_t avg_pool_index = mkldnn_emitter->reserve_primitive_space (3 );
@@ -155,6 +164,7 @@ namespace ngraph
155
164
avg_pool_desc,
156
165
avg_pool_fwd_desc,
157
166
avg_pool_index,
167
+ scratchpad_size,
158
168
delta_buffer_index,
159
169
out_buffer_index](CPURuntimeContext* ctx,
160
170
CPUExecutionContext* ectx) {
@@ -174,7 +184,11 @@ namespace ngraph
174
184
ctx, deps[1 ], ctx->buffer_data [out_buffer_index]);
175
185
176
186
cpu::mkldnn_utils::mkldnn_invoke_primitive (
177
- ctx, avg_pool_index, deps, cpu::mkldnn_utils::OpType::AVGPOOLBACKPROP);
187
+ ctx,
188
+ avg_pool_index,
189
+ deps,
190
+ cpu::mkldnn_utils::OpType::AVGPOOLBACKPROP,
191
+ scratchpad_size);
178
192
};
179
193
functors.emplace_back (functor);
180
194
}
0 commit comments