Skip to content

Commit

Permalink
fix code_style
Browse files Browse the repository at this point in the history
  • Loading branch information
a162837 committed Nov 5, 2024
1 parent 930ecc0 commit 8f45f09
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 120 deletions.
40 changes: 26 additions & 14 deletions python/paddle/tensor/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -3815,10 +3815,14 @@ def clip(
max_n = check_clip_tensor(x, max, max_, value_dtype, 'max')

min_n = (
paddle.broadcast_to(min_n, x.shape) if min_n.shape != x.shape else min_n
paddle.broadcast_to(min_n, x.shape)
if min_n.shape != x.shape
else min_n
)
max_n = (
paddle.broadcast_to(max_n, x.shape) if max_n.shape != x.shape else max_n
paddle.broadcast_to(max_n, x.shape)
if max_n.shape != x.shape
else max_n
)

output_min = paddle.where(x < min_n, min_n, x)
Expand All @@ -3827,17 +3831,19 @@ def clip(

else:
if in_dynamic_or_pir_mode():
if isinstance(min, Variable):
min = min.item(0)
if isinstance(max, Variable):
max = max.item(0)
if isinstance(min, (Variable, paddle.pir.Value)):
min = min.item()
if isinstance(max, (Variable, paddle.pir.Value)):
max = max.item()
min = min_ if min is None else min
max = max_ if max is None else max
return _C_ops.clip(x, min, max)
else:
if min is not None:
check_type(min, 'min', (float, int, Variable), 'clip')
if isinstance(min, Variable):
check_type(
min, 'min', (float, int, Variable, paddle.Tensor), 'clip'
)
if isinstance(min, (Variable, paddle.Tensor, paddle.pir.Value)):
check_dtype(
min.dtype,
'min',
Expand All @@ -3846,8 +3852,10 @@ def clip(
'(When the type of min in clip is Variable.)',
)
if max is not None:
check_type(max, 'max', (float, int, Variable), 'clip')
if isinstance(max, Variable):
check_type(
max, 'max', (float, int, Variable, paddle.Tensor), 'clip'
)
if isinstance(max, (Variable, paddle.Tensor, paddle.pir.Value)):
check_dtype(
max.dtype,
'max',
Expand All @@ -3866,13 +3874,13 @@ def clip(
inputs = {'X': x}
attrs = {'min': min_, 'max': max_}

if isinstance(min, Variable):
if isinstance(min, (Variable, paddle.Tensor, paddle.pir.Value)):
min.stop_gradient = True
inputs['Min'] = min
elif min is not None:
attrs['min'] = min

if isinstance(max, Variable):
if isinstance(max, (Variable, paddle.Tensor, paddle.pir.Value)):
max.stop_gradient = True
inputs['Max'] = max
elif max is not None:
Expand Down Expand Up @@ -3926,10 +3934,14 @@ def clip_(
min = check_clip_tensor(x, min, fmin, x.dtype, 'min')

max_expand = (
paddle.broadcast_to(max, x.shape) if max.shape != x.shape else max
paddle.broadcast_to(max, x.shape)
if max.shape != x.shape
else max
)
min_expand = (
paddle.broadcast_to(min, x.shape) if min.shape != x.shape else min
paddle.broadcast_to(min, x.shape)
if min.shape != x.shape
else min
)

paddle.where_(x > min_expand, x, min_expand)
Expand Down
2 changes: 1 addition & 1 deletion test/legacy_test/test_clip_op.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand Down
Loading

0 comments on commit 8f45f09

Please sign in to comment.