Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
a162837 committed Jan 31, 2025
1 parent 7adbdbe commit e4a4c12
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions paddle/phi/kernels/onednn/clip_tensor_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -132,23 +132,23 @@ void ClipTensorGradKernel(const Context& dev_ctx,
phi::DenseTensor x_ls_min;
MetaTensor meta_x_ls_min(&x_ls_min);
UnchangedInferMeta(x, &meta_x_ls_min);
BinaryFun<T, dnnl::algorithm::binary_lt>(dev_ctx, min, x, -1, *x_ls_min);
BinaryFun<T, dnnl::algorithm::binary_lt>(dev_ctx, min, x, -1, &x_ls_min);
phi::DenseTensor cast_x_ls_min;
cast_x_ls_min = phi::Cast(dev_ctx, *x_ls_min, x.dtype());
cast_x_ls_min = phi::Cast(dev_ctx, x_ls_min, x.dtype());

phi::DenseTensor x_ls_max;
MetaTensor meta_x_ls_max(&x_ls_max);
UnchangedInferMeta(x, &meta_x_ls_max);
BinaryFun<T, dnnl::algorithm::binary_lt>(dev_ctx, x, max, -1, *x_ls_max);
BinaryFun<T, dnnl::algorithm::binary_lt>(dev_ctx, x, max, -1, &x_ls_max);
phi::DenseTensor cast_x_ls_max;
cast_x_ls_max = phi::Cast(dev_ctx, *x_ls_max, x.dtype());
cast_x_ls_max = phi::Cast(dev_ctx, x_ls_max, x.dtype());

phi::DenseTensor mask_zero;
MetaTensor meta_mask_zero(&mask_zero);
UnchangedInferMeta(x, &meta_mask_zero);
BinaryFun<T, dnnl::algorithm::binary_mul>(dev_ctx, *cast_x_ls_min, *cast_x_ls_max, -1, *mask_zero);
BinaryFun<T, dnnl::algorithm::binary_mul>(dev_ctx, cast_x_ls_min, cast_x_ls_max, -1, &mask_zero);

BinaryFun<T, dnnl::algorithm::binary_mul>(dev_ctx, *mask_zero, out_grad, -1, x_grad);
BinaryFun<T, dnnl::algorithm::binary_mul>(dev_ctx, mask_zero, out_grad, -1, x_grad);
}
} // namespace phi

Expand Down

0 comments on commit e4a4c12

Please sign in to comment.