16
16
17
17
#include " paddle/phi/backends/xpu/enforce_xpu.h"
18
18
#include " paddle/phi/core/kernel_registry.h"
19
+ #include " paddle/phi/infermeta/unary.h"
19
20
#include " paddle/phi/kernels/cast_kernel.h"
20
21
#include " paddle/phi/kernels/compare_kernel.h"
21
22
#include " paddle/phi/kernels/full_kernel.h"
@@ -31,14 +32,8 @@ void ClipTensorGradKernel(const Context& dev_ctx,
31
32
const DenseTensor& max,
32
33
const DenseTensor& out_grad,
33
34
DenseTensor* x_grad) {
34
- DenseTensor ex_min;
35
- MetaTensor meta_min (&ex_min);
36
- CastInferMeta (min, x.dtype (), &meta_min);
37
- DenseTensor ex_max;
38
- MetaTensor meta_max (&ex_max);
39
- CastInferMeta (max, x.dtype (), &meta_max);
40
- phi::CastKernel<T, Context>(dev_ctx, min, x.dtype (), &ex_min);
41
- phi::CastKernel<T, Context>(dev_ctx, max, x.dtype (), &ex_max);
35
+ DenseTensor ex_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype ());
36
+ DenseTensor ex_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype ());
42
37
43
38
phi::DenseTensor x_ls_min;
44
39
MetaTensor meta_x_ls_min (&x_ls_min);
@@ -56,12 +51,12 @@ void ClipTensorGradKernel(const Context& dev_ctx,
56
51
MetaTensor meta_out (&out);
57
52
UnchangedExceptDtypeInferMeta (x, &meta_out);
58
53
meta_out.set_dtype (phi::DataType::BOOL);
59
- LogicalAndKernel<bool , Context>(dev_ctx, x_ls_min, x_ls_max, &out);
54
+ phi:: LogicalAndKernel<bool , Context>(dev_ctx, x_ls_min, x_ls_max, &out);
60
55
61
56
phi::DenseTensor zero_tensor;
62
57
MetaTensor meta_zero (&zero_tensor);
63
58
UnchangedInferMeta (x_grad, &meta_zero);
64
- FullKernel<T, Context>(dev_ctx,
59
+ phi:: FullKernel<T, Context>(dev_ctx,
65
60
common::vectorize (x_grad->dims ()),
66
61
0 .0f ,
67
62
zero_tensor.dtype (),
0 commit comments