Commit 173660b

[oneDNN] Cache oneDNN stream not to recreate in each oneDNN op (PaddlePaddle#30358)
1 parent ae0f88a

26 files changed (+81 / -55 lines)
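
Every hunk in this commit applies the same substitution: instead of constructing a fresh mkldnn::stream / dnnl::stream from the engine inside each oneDNN kernel, the kernel now fetches a cached stream from the thread-local device-context state via platform::MKLDNNDeviceContext::tls().get_stream(). Below is a minimal sketch of the caching idea using only the public oneDNN C++ API; the helper name GetThreadLocalStream is hypothetical and is not Paddle's actual implementation, which lives in the device-context code of this commit.

```cpp
// Hypothetical sketch of a per-thread cached oneDNN stream (not Paddle's code).
#include <dnnl.hpp>

namespace sketch {

// Returns the same dnnl::stream for every call made on a given thread.
// Note: the stream is bound to the engine passed on the *first* call in that
// thread; this mirrors a "one engine, one cached stream per thread" assumption.
dnnl::stream& GetThreadLocalStream(const dnnl::engine& engine) {
  thread_local dnnl::stream cached_stream(engine);
  return cached_stream;
}

}  // namespace sketch

int main() {
  dnnl::engine cpu_engine(dnnl::engine::kind::cpu, 0);

  // Both references point at the same stream object, so kernels can call
  // execute(...) / wait() without paying for stream construction each time.
  dnnl::stream& s1 = sketch::GetThreadLocalStream(cpu_engine);
  dnnl::stream& s2 = sketch::GetThreadLocalStream(cpu_engine);
  return (&s1 == &s2) ? 0 : 1;
}
```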

paddle/fluid/framework/data_layout_transform.cc

Lines changed: 1 addition & 1 deletion

@@ -193,7 +193,7 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
   auto reorder_p =
       handler.AcquireReorder(reorder_dst_memory_p, reorder_src_memory_p);

-  mkldnn::stream astream(cpu_engine);
+  auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
   platform::RecordEvent record_reorder("ext_reorder",
                                        platform::EventRole::kUniqueOp);
   reorder_p->execute(astream, *reorder_src_memory_p, *reorder_dst_memory_p);

paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc

Lines changed: 1 addition & 1 deletion

@@ -48,7 +48,7 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
     platform::ReorderMKLDNNHandler handler(tz, dout->type(), dout_type, dev_ctx,
                                            onednn_engine, key);

-    mkldnn::stream astream(onednn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     auto reorder_src_memory_p = handler.AcquireSrcMemory(
         dout->format(), platform::to_void_cast(dout->data<T>()));


paddle/fluid/operators/elementwise/mkldnn/elementwise_mkldnn_op.h

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ class EltwiseMKLDNNKernel : public framework::OpKernel<T> {

     const auto binary_prim = handler.AcquireForwardPrimitive();

-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

     const std::unordered_map<int, dnnl::memory> args = {
         {DNNL_ARG_SRC_0, *src_x_memory},

paddle/fluid/operators/fused/mkldnn/fusion_gru_mkldnn_op.cc

Lines changed: 4 additions & 4 deletions

@@ -246,7 +246,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
       memory_p = std::make_shared<dnnl::memory>(this->fwd_pd_->src_iter_desc(),
                                                 this->engine_);

-      dnnl::stream astream(this->engine_);
+      auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
       dnnl::reorder(user_h0_memory, *memory_p, attr_)
           .execute(astream, user_h0_memory, *memory_p);

@@ -284,7 +284,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
       memory_p = std::make_shared<dnnl::memory>(
           this->fwd_pd_->weights_layer_desc(), this->engine_);

-      dnnl::stream astream(this->engine_);
+      auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
       dnnl::reorder(user_memory, *memory_p, attr_)
           .execute(astream, user_memory, *memory_p);

@@ -337,7 +337,7 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
      memory_p = std::make_shared<dnnl::memory>(
          this->fwd_pd_->weights_iter_desc(), this->engine_);

-      dnnl::stream astream(this->engine_);
+      auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
       dnnl::reorder(user_memory, *memory_p, attr_)
           .execute(astream, user_memory, *memory_p);

@@ -469,7 +469,7 @@ class FusionGRUMKLDNNKernel : public framework::OpKernel<T> {

     auto gru_forward_p = handler.AcquireForwardPrimitive();

-    dnnl::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     gru_forward_p->execute(astream, gru_args);
     astream.wait();


paddle/fluid/operators/fused/mkldnn/multi_gru_mkldnn_op.cc

Lines changed: 5 additions & 5 deletions

@@ -292,7 +292,7 @@ class MultiGRUHandler {

     auto gru_forward_p0 = AcquireGruPrimitive(layer, dir);

-    dnnl::stream astream(engine_);
+    auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
     gru_forward_p0->execute(astream, gru_args);
     astream.wait();
     return out_mem;

@@ -315,7 +315,7 @@ class MultiGRUHandler {
       memory_p = std::make_shared<dnnl::memory>(
           gru_pds_[{layer, dir}]->src_iter_desc(), engine_);

-      dnnl::stream astream(engine_);
+      auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
       dnnl::reorder(user_h0_memory, *memory_p, attrs_[2 * layer + (dir == R2L)])
           .execute(astream, user_h0_memory, *memory_p);

@@ -354,7 +354,7 @@ class MultiGRUHandler {
       memory_p = std::make_shared<dnnl::memory>(
           gru_pds_[{layer, dir}]->weights_layer_desc(), engine_);

-      dnnl::stream astream(engine_);
+      auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
       dnnl::reorder(user_memory, *memory_p, attrs_[2 * layer + (dir == R2L)])
           .execute(astream, user_memory, *memory_p);

@@ -410,7 +410,7 @@ class MultiGRUHandler {
       memory_p = std::make_shared<dnnl::memory>(
           gru_pds_[{layer, dir}]->weights_iter_desc(), engine_);

-      dnnl::stream astream(engine_);
+      auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
       dnnl::reorder(user_memory, *memory_p, attrs_[2 * layer + (dir == R2L)])
           .execute(astream, user_memory, *memory_p);

@@ -516,7 +516,7 @@ class MultiGRUHandler {

     auto concat_p = AcquireConcatPrimitive(layer);

-    dnnl::stream astream(engine_);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     concat_p->execute(astream, concat_args);
     astream.wait();
     return out_mem;

paddle/fluid/operators/mkldnn/activation_mkldnn_op.cc

Lines changed: 2 additions & 2 deletions

@@ -112,7 +112,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
   auto dst_memory_p = is_inplaced ? src_memory_p : handler.AcquireDstMemory(y);
   auto activation_p = handler.AcquireForwardPrimitive();

-  mkldnn::stream astream(dev_ctx.GetEngine());
+  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
   activation_p->execute(astream, {{MKLDNN_ARG_FROM, *src_memory_p},
                                   {MKLDNN_ARG_TO, *dst_memory_p}});
   astream.wait();

@@ -158,7 +158,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   auto diff_src_memory_p = handler.AcquireDiffSrcMemory(diff_x);
   auto activation_backward_p = handler.AcquireBackwardPrimitive();

-  mkldnn::stream astream(dev_ctx.GetEngine());
+  auto &astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
   activation_backward_p->execute(astream,
                                  {{MKLDNN_ARG_SRC, *src_memory_p},
                                   {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},

paddle/fluid/operators/mkldnn/batch_norm_mkldnn_op.cc

Lines changed: 2 additions & 2 deletions

@@ -220,7 +220,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     y->set_layout(DataLayout::kMKLDNN);
     y->set_format(platform::GetMKLDNNFormat(*dst_memory));

-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
     batch_norm_p->execute(astream,
                           {{MKLDNN_ARG_SRC, *src_memory},
                            {MKLDNN_ARG_SCALE_SHIFT, *scaleshift_memory},

@@ -321,7 +321,7 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     // finally create batch_norm backward primitive
     auto batch_norm_bwd_p = handler.AcquireBackwardPrimitive();

-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
     batch_norm_bwd_p->execute(
         astream, {{MKLDNN_ARG_SRC, *src_memory},
                   {MKLDNN_ARG_MEAN, *mean_memory},

paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc

Lines changed: 1 addition & 1 deletion

@@ -202,7 +202,7 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         output->mutable_data<T>(place, concat_pd->dst_desc().get_size()));
     }

-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     std::unordered_map<int, memory> args;
     for (size_t i = 0; i < multi_input.size(); ++i) {
       args.insert({MKLDNN_ARG_MULTIPLE_SRC + i, (*srcs).at(i)});

paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc

Lines changed: 3 additions & 3 deletions

@@ -471,7 +471,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
     }

-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     conv_p->execute(astream, args);
     astream.wait();

@@ -553,7 +553,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
         dev_ctx.GetBlob(prim_key));

-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

     if (conv_p == nullptr || !is_test) {
       float fuse_alpha = ctx.Attr<float>("fuse_alpha");

@@ -1045,7 +1045,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
         user_weights_md, to_void_cast<T>(filter_data));
     auto user_diff_dst_memory_p = handler.AcquireDiffDstMemory(
         user_diff_dst_md, to_void_cast<T>(output_grad_data));
-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if (filter_grad) {
       auto src_memory_p = handler.AcquireSrcMemoryFromWeightsPrimitive(
           user_src_memory_p, pipeline);

paddle/fluid/operators/mkldnn/conv_transpose_mkldnn_op.cc

Lines changed: 1 addition & 1 deletion

@@ -242,7 +242,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {

     auto conv_p = handler.AcquireConvolution();

-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if (bias) {
       const T* bias_data = bias->data<T>();
       auto user_bias_md = platform::MKLDNNMemDesc(

paddle/fluid/operators/mkldnn/dequantize_mkldnn_op.cc

Lines changed: 1 addition & 1 deletion

@@ -124,7 +124,7 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
       dst_memory->set_data_handle(output->mutable_data<float>(ctx.GetPlace()));
     }

-    mkldnn::stream astream(engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     reorder_p->execute(astream, *src_memory, *dst_memory);
     astream.wait();


paddle/fluid/operators/mkldnn/fc_mkldnn_op.cc

Lines changed: 3 additions & 3 deletions

@@ -137,7 +137,7 @@ class FCPrimitiveFactory {
   }

   void Execute() {
-    mkldnn::stream astream(engine_);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if (bias_) {
       fc_->execute(astream, {{MKLDNN_ARG_SRC, *input_},
                              {MKLDNN_ARG_WEIGHTS, *weights_},

@@ -280,7 +280,7 @@ class FCPrimitiveFactory {
     auto dst_mem = std::make_shared<memory>(dst_desc, engine_);

     auto reorder = mkldnn::reorder(src_mem, *dst_mem);
-    mkldnn::stream astream(engine_);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

     {
       platform::RecordEvent record_reorder("int_reorder",

@@ -309,7 +309,7 @@ class FCPrimitiveFactory {
     attributes.set_output_scales(mask, scale_data);
     auto reorder = mkldnn::reorder(*src_mem, *dst_mem, attributes);

-    mkldnn::stream astream(engine_);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     {
       platform::RecordEvent record_reorder("int_reorder",
                                            platform::EventRole::kUniqueOp);

paddle/fluid/operators/mkldnn/interpolate_mkldnn_op.cc

Lines changed: 1 addition & 1 deletion

@@ -154,7 +154,7 @@ class InterpolateMKLDNNKernel : public framework::OpKernel<T> {
     auto resampling_prim = handler.AcquireForwardPrimitive();
     const std::unordered_map<int, dnnl::memory> args = {
         {DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}};
-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     resampling_prim->execute(astream, args);
     astream.wait();


paddle/fluid/operators/mkldnn/layer_norm_mkldnn_op.cc

Lines changed: 1 addition & 1 deletion

@@ -120,7 +120,7 @@ class LayerNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {

     auto layer_norm_p = handler.AcquireForwardPrimitive();

-    dnnl::stream astream(dev_ctx.GetEngine());
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     std::unordered_map<int, dnnl::memory> args;

     args.insert({DNNL_ARG_SRC, *src_memory});

paddle/fluid/operators/mkldnn/lrn_mkldnn_op.cc

Lines changed: 2 additions & 2 deletions

@@ -59,7 +59,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto workspace_memory = handler.AcquireWorkspaceMemory(mid);
     mid->set_layout(framework::DataLayout::kMKLDNN);

-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if (!workspace_memory->get_desc().is_zero()) {
       mid->set_format(platform::GetMKLDNNFormat(*workspace_memory));
       lrn_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},

@@ -118,7 +118,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {

     auto lrn_bwd = handler.AcquireBackwardPrimitive();

-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     lrn_bwd->execute(astream, {{MKLDNN_ARG_SRC, *src_memory},
                                {MKLDNN_ARG_DIFF_DST, *diff_dst_memory},
                                {MKLDNN_ARG_DIFF_SRC, *diff_src_memory},

paddle/fluid/operators/mkldnn/mul_mkldnn_op.cc

Lines changed: 4 additions & 5 deletions

@@ -109,7 +109,7 @@ class MulPrimitiveFactory {

     auto reorder = mkldnn::reorder(reorder_pd);

-    mkldnn::stream astream(engine_);
+    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
     {
       platform::RecordEvent record_reorder("int_reorder",
                                            platform::EventRole::kUniqueOp);

@@ -184,7 +184,7 @@ class MulPrimitiveFactory {
   }

   void Execute() {
-    mkldnn::stream astream(engine_);
+    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
     (*mul_).execute(astream, {{MKLDNN_ARG_SRC, *x_input_},
                               {MKLDNN_ARG_WEIGHTS, *y_input_},
                               {MKLDNN_ARG_DST, *output_}});

@@ -270,8 +270,7 @@ class MulPrimitiveFactory {

     auto reorder = mkldnn::reorder(src_mem, dst_mem);

-    mkldnn::stream astream(engine_);
-
+    auto &astream = platform::MKLDNNDeviceContext::tls().get_stream();
     {
       platform::RecordEvent record_reorder("int_reorder",
                                            platform::EventRole::kUniqueOp);

@@ -355,7 +354,7 @@ class MulMKLDNNKernel : public framework::OpKernel<XT> {
                           "Operator DNNL Mul must use CPUPlace"));
     platform::MKLDNNDeviceContext::tls().log_lib_version();
     auto &dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
-    const auto &mkldnn_engine = dev_ctx.GetEngine();
+    auto &mkldnn_engine = dev_ctx.GetEngine();

     const Tensor *x = ctx.Input<Tensor>("X");
     const Tensor *y = ctx.Input<Tensor>("Y");

paddle/fluid/operators/mkldnn/pool_mkldnn_op.cc

Lines changed: 2 additions & 2 deletions

@@ -51,7 +51,7 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {

     auto pool_p = handler.AcquireForwardPrimitive();

-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if ((ctx.Attr<bool>("is_test") == false) &&
         (ctx.Attr<std::string>("pooling_type") == "max")) {
       // Training

@@ -154,7 +154,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {

     auto pool_bwd_p = handler.AcquireBackwardPrimitive();

-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     if (pooling_type == "max") {
       // Max - pooling needs Workspace
       auto workspace_memory = handler.AcquireWorkspaceMemory();

paddle/fluid/operators/mkldnn/quantize_mkldnn_op.cc

Lines changed: 1 addition & 1 deletion

@@ -140,7 +140,7 @@ class QuantOpKernel : public framework::OpKernel<T> {
      }
    }

-    mkldnn::stream astream(engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    {
      platform::RecordEvent record_reorder("int_reorder",
                                           platform::EventRole::kUniqueOp);

paddle/fluid/operators/mkldnn/requantize_mkldnn_op.cc

Lines changed: 1 addition & 1 deletion

@@ -137,7 +137,7 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
      }
    }

-    dnnl::stream astream(engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
    {
      platform::RecordEvent record_reorder("int_reorder",
                                           platform::EventRole::kUniqueOp);

paddle/fluid/operators/mkldnn/softmax_mkldnn_op.cc

Lines changed: 2 additions & 2 deletions

@@ -117,7 +117,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {

     auto softmax_p = handler.AcquireForwardPrimitive();

-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto& astream = paddle::platform::MKLDNNDeviceContext::tls().get_stream();
     softmax_p->execute(astream, {{DNNL_ARG_SRC, *softmax_src_memory_p},
                                  {DNNL_ARG_DST, *softmax_dst_memory_p}});
     astream.wait();

@@ -169,7 +169,7 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {

     auto softmax_bwd_p = handler.AcquireBackwardPrimitive();

-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     softmax_bwd_p->execute(astream,
                            {{MKLDNN_ARG_DST, *dst_memory_p},
                             {MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},

paddle/fluid/operators/mkldnn/sum_mkldnn_op.cc

Lines changed: 1 addition & 1 deletion

@@ -178,7 +178,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     }
     args.insert({MKLDNN_ARG_DST, *dst_mem});

-    mkldnn::stream astream(dev_ctx.GetEngine());
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     sum_p->execute(astream, args);
     astream.wait();


paddle/fluid/operators/mkldnn/transpose_mkldnn_op.cc

Lines changed: 2 additions & 2 deletions

@@ -61,7 +61,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
                                                 transpose_src_memory_p);

-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     transpose_p->execute(astream, *transpose_src_memory_p,
                          *transpose_dst_memory_p);
     astream.wait();

@@ -116,7 +116,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto transpose_p = handler.AcquireTranspose(transpose_dst_memory_p,
                                                 transpose_src_memory_p);

-    mkldnn::stream astream(mkldnn_engine);
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
     transpose_p->execute(astream, *transpose_src_memory_p,
                          *transpose_dst_memory_p);
     astream.wait();
