Skip to content

Commit cf17e35

Browse files
committed
fix bug
1 parent 1c470af commit cf17e35

8 files changed

+29
-81
lines changed

paddle/phi/kernels/cpu/clip_tensor_grad_kernel.cc

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -27,14 +27,8 @@ void ClipTensorGradKernel(const Context& dev_ctx,
2727
const DenseTensor& max,
2828
const DenseTensor& out_grad,
2929
DenseTensor* x_grad) {
30-
DenseTensor tem_min;
31-
MetaTensor meta_tem_min(&tem_min);
32-
CastInferMeta(min, x.dtype(), &meta_tem_min);
33-
CastKernel<T, Context>(dev_ctx, min, x.dtype(), &tem_min);
34-
DenseTensor tem_max;
35-
MetaTensor meta_tem_max(&tem_max);
36-
CastInferMeta(max, x.dtype(), &meta_tem_max);
37-
CastKernel<T, Context>(dev_ctx, max, x.dtype(), &tem_max);
30+
DenseTensor tem_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
31+
DenseTensor tem_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());
3832

3933
const T* x_data = x.data<T>();
4034
const T* min_data = tem_min.data<T>();

paddle/phi/kernels/cpu/clip_tensor_kernel.cc

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -27,14 +27,8 @@ void ClipTensorKernel(const Context& dev_ctx,
2727
const DenseTensor& min,
2828
const DenseTensor& max,
2929
DenseTensor* out) {
30-
DenseTensor tem_min;
31-
MetaTensor meta_tem_min(&tem_min);
32-
CastInferMeta(min, x.dtype(), &meta_tem_min);
33-
CastKernel<T, Context>(dev_ctx, min, x.dtype(), &tem_min);
34-
DenseTensor tem_max;
35-
MetaTensor meta_tem_max(&tem_max);
36-
CastInferMeta(max, x.dtype(), &meta_tem_max);
37-
CastKernel<T, Context>(dev_ctx, max, x.dtype(), &tem_max);
30+
DenseTensor tem_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
31+
DenseTensor tem_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());
3832

3933
const T* x_data = x.data<T>();
4034
const T* min_data = tem_min.data<T>();

paddle/phi/kernels/gpu/clip_tensor_grad_kernel.cu

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -44,14 +44,8 @@ void ClipTensorGradKernel(const Context& dev_ctx,
4444
const DenseTensor& max,
4545
const DenseTensor& out_grad,
4646
DenseTensor* x_grad) {
47-
DenseTensor tem_min;
48-
MetaTensor meta_tem_min(&tem_min);
49-
CastInferMeta(min, x.dtype(), &meta_tem_min);
50-
CastKernel<T, Context>(dev_ctx, min, x.dtype(), &tem_min);
51-
DenseTensor tem_max;
52-
MetaTensor meta_tem_max(&tem_max);
53-
CastInferMeta(max, x.dtype(), &meta_tem_max);
54-
CastKernel<T, Context>(dev_ctx, max, x.dtype(), &tem_max);
47+
DenseTensor tem_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
48+
DenseTensor tem_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());
5549

5650
const T* x_data = x.data<T>();
5751
auto numel = x.numel();

paddle/phi/kernels/gpu/clip_tensor_kernel.cu

Lines changed: 3 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15-
#include "paddle/phi/kernels/clip_kernel.h"
15+
#include "paddle/phi/kernels/clip_tensor_kernel.h"
1616

1717
#include "paddle/phi/backends/gpu/gpu_context.h"
1818
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
@@ -39,14 +39,8 @@ void ClipTensorKernel(const Context& dev_ctx,
3939
const DenseTensor& min,
4040
const DenseTensor& max,
4141
DenseTensor* out) {
42-
DenseTensor tem_min;
43-
MetaTensor meta_tem_min(&tem_min);
44-
CastInferMeta(min, x.dtype(), &meta_tem_min);
45-
CastKernel<T, Context>(dev_ctx, min, x.dtype(), &tem_min);
46-
DenseTensor tem_max;
47-
MetaTensor meta_tem_max(&tem_max);
48-
CastInferMeta(max, x.dtype(), &meta_tem_max);
49-
CastKernel<T, Context>(dev_ctx, max, x.dtype(), &tem_max);
42+
DenseTensor tem_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
43+
DenseTensor tem_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());
5044

5145
std::vector<const DenseTensor*> ins = {&x, &tem_min, &tem_max};
5246
std::vector<DenseTensor*> outs = {out};

paddle/phi/kernels/onednn/clip_tensor_grad_kernel.cc

Lines changed: 5 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -13,11 +13,12 @@
1313
// limitations under the License.
1414

1515
#include "paddle/phi/kernels/clip_tensor_grad_kernel.h"
16-
#include "paddle/phi/kernels/cast_kernel.h"
17-
#include "paddle/phi/kernels/elementwise_kernel.h"
1816

1917
#include "paddle/phi/backends/onednn/onednn_reuse.h"
2018
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/infermeta/unary.h"
20+
#include "paddle/phi/kernels/cast_kernel.h"
21+
#include "paddle/phi/kernels/elementwise_kernel.h"
2122

2223
namespace phi {
2324
template <typename T, typename Context>
@@ -27,14 +28,6 @@ void ClipTensorGradKernel(const Context& dev_ctx,
2728
const DenseTensor& max,
2829
const DenseTensor& out_grad,
2930
DenseTensor* x_grad) {
30-
DenseTensor ex_min;
31-
MetaTensor meta_min(&ex_min);
32-
CastInferMeta(min, x.dtype(), &meta_min);
33-
DenseTensor ex_max;
34-
MetaTensor meta_max(&ex_max);
35-
CastInferMeta(max, x.dtype(), &meta_max);
36-
phi::CastKernel<T, Context>(dev_ctx, min, x.dtype(), &ex_min);
37-
phi::CastKernel<T, Context>(dev_ctx, max, x.dtype(), &ex_max);
3831

3932
const auto& onednn_engine = dev_ctx.GetEngine();
4033
auto& astream = OneDNNContext::tls().get_stream();
@@ -53,8 +46,8 @@ void ClipTensorGradKernel(const Context& dev_ctx,
5346
auto* tem_max_mask = &t_max_mask;
5447
auto* tem_zero_mask = &t_zero_mask;
5548
auto* non_const_x = &x;
56-
auto* non_const_min = &ex_min;
57-
auto* non_const_max = &ex_max;
49+
auto* non_const_min = &min;
50+
auto* non_const_max = &max;
5851
auto* non_const_out_grad = &out_grad;
5952

6053
funcs::BinaryOneDNNHandler<T> Lesshandler(dnnl::algorithm::binary_lt,

paddle/phi/kernels/onednn/clip_tensor_kernel.cc

Lines changed: 5 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -13,11 +13,12 @@
1313
// limitations under the License.
1414

1515
#include "paddle/phi/kernels/clip_tensor_kernel.h"
16-
#include "paddle/phi/kernels/cast_kernel.h"
17-
#include "paddle/phi/kernels/elementwise_kernel.h"
1816

1917
#include "paddle/phi/backends/onednn/onednn_reuse.h"
2018
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/infermeta/unary.h"
20+
#include "paddle/phi/kernels/cast_kernel.h"
21+
#include "paddle/phi/kernels/elementwise_kernel.h"
2122

2223
namespace phi {
2324
template <typename T, typename Context>
@@ -26,14 +27,6 @@ void ClipTensorKernel(const Context& dev_ctx,
2627
const DenseTensor& min,
2728
const DenseTensor& max,
2829
DenseTensor* out) {
29-
DenseTensor ex_min;
30-
MetaTensor meta_min(&ex_min);
31-
CastInferMeta(min, x.dtype(), &meta_min);
32-
DenseTensor ex_max;
33-
MetaTensor meta_max(&ex_max);
34-
CastInferMeta(max, x.dtype(), &meta_max);
35-
phi::CastKernel<T, Context>(dev_ctx, min, x.dtype(), &ex_min);
36-
phi::CastKernel<T, Context>(dev_ctx, max, x.dtype(), &ex_max);
3730

3831
const auto& onednn_engine = dev_ctx.GetEngine();
3932
auto& astream = OneDNNContext::tls().get_stream();
@@ -43,8 +36,8 @@ void ClipTensorKernel(const Context& dev_ctx,
4336
UnchangedInferMeta(x, &meta_out);
4437
auto* tem_out = &t_out;
4538
auto* non_const_x = &x;
46-
auto* non_const_min = &ex_min;
47-
auto* non_const_max = &ex_max;
39+
auto* non_const_min = &min;
40+
auto* non_const_max = &max;
4841

4942
funcs::BinaryOneDNNHandler<T> MAXhandler(dnnl::algorithm::binary_max,
5043
-1,

paddle/phi/kernels/xpu/clip_tensor_grad_kernel.cc

Lines changed: 5 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,7 @@
1616

1717
#include "paddle/phi/backends/xpu/enforce_xpu.h"
1818
#include "paddle/phi/core/kernel_registry.h"
19+
#include "paddle/phi/infermeta/unary.h"
1920
#include "paddle/phi/kernels/cast_kernel.h"
2021
#include "paddle/phi/kernels/compare_kernel.h"
2122
#include "paddle/phi/kernels/full_kernel.h"
@@ -31,14 +32,8 @@ void ClipTensorGradKernel(const Context& dev_ctx,
3132
const DenseTensor& max,
3233
const DenseTensor& out_grad,
3334
DenseTensor* x_grad) {
34-
DenseTensor ex_min;
35-
MetaTensor meta_min(&ex_min);
36-
CastInferMeta(min, x.dtype(), &meta_min);
37-
DenseTensor ex_max;
38-
MetaTensor meta_max(&ex_max);
39-
CastInferMeta(max, x.dtype(), &meta_max);
40-
phi::CastKernel<T, Context>(dev_ctx, min, x.dtype(), &ex_min);
41-
phi::CastKernel<T, Context>(dev_ctx, max, x.dtype(), &ex_max);
35+
DenseTensor ex_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
36+
DenseTensor ex_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());
4237

4338
phi::DenseTensor x_ls_min;
4439
MetaTensor meta_x_ls_min(&x_ls_min);
@@ -56,12 +51,12 @@ void ClipTensorGradKernel(const Context& dev_ctx,
5651
MetaTensor meta_out(&out);
5752
UnchangedExceptDtypeInferMeta(x, &meta_out);
5853
meta_out.set_dtype(phi::DataType::BOOL);
59-
LogicalAndKernel<bool, Context>(dev_ctx, x_ls_min, x_ls_max, &out);
54+
phi::LogicalAndKernel<bool, Context>(dev_ctx, x_ls_min, x_ls_max, &out);
6055

6156
phi::DenseTensor zero_tensor;
6257
MetaTensor meta_zero(&zero_tensor);
6358
UnchangedInferMeta(x_grad, &meta_zero);
64-
FullKernel<T, Context>(dev_ctx,
59+
phi::FullKernel<T, Context>(dev_ctx,
6560
common::vectorize(x_grad->dims()),
6661
0.0f,
6762
zero_tensor.dtype(),

paddle/phi/kernels/xpu/clip_tensor_kernel.cc

Lines changed: 5 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -29,19 +29,10 @@ void ClipTensorKernel(const Context& dev_ctx,
2929
const DenseTensor& min,
3030
const DenseTensor& max,
3131
DenseTensor* out) {
32-
DenseTensor tem_min;
33-
MetaTensor meta_tem_min(&tem_min);
34-
CastInferMeta(min, x.dtype(), &meta_tem_min);
35-
CastKernel<T, Context>(dev_ctx, min, x.dtype(), &tem_min);
36-
DenseTensor tem_max;
37-
MetaTensor meta_tem_max(&tem_max);
38-
CastInferMeta(max, x.dtype(), &meta_tem_max);
39-
CastKernel<T, Context>(dev_ctx, max, x.dtype(), &tem_max);
32+
DenseTensor tem_min = phi::Cast<T, Context>(dev_ctx, min, x.dtype());
33+
DenseTensor tem_max = phi::Cast<T, Context>(dev_ctx, max, x.dtype());
4034

41-
DenseTensor tem_max_out;
42-
MetaTensor meta_tem_max_out(&tem_max_out);
43-
ElementwiseInferMeta(min, x, &meta_tem_max_out);
44-
MaximumKernel<T, Context>(dev_ctx, min, x, &tem_max_out);
35+
DenseTensor tem_max_out = phi::Maximum<T, Context>(dev_ctx, min, x);
4536
MinimumKernel<T, Context>(dev_ctx, tem_max_out, max, out);
4637
}
4738

@@ -54,5 +45,5 @@ PD_REGISTER_KERNEL(clip_tensor,
5445
float,
5546
phi::dtype::float16,
5647
phi::dtype::bfloat16,
57-
int64_t,
58-
int) {}
48+
int,
49+
int64_t) {}

0 commit comments

Comments
 (0)