Skip to content

Commit

Permalink
[Comp] Fix take_along_axis_grad when duplicated entries in indices (#…
Browse files Browse the repository at this point in the history
…70250)

* fix take_along_axis_grad when duplicated entries in indices

* update unitest

* add grad check

* use include_self=True to avoid BUG
  • Loading branch information
HydrogenSulfate authored Dec 18, 2024
1 parent 9986a56 commit c7d49ec
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 2 deletions.
8 changes: 6 additions & 2 deletions paddle/fluid/primitive/decomp_rule/decomp_vjp/details.h
Original file line number Diff line number Diff line change
Expand Up @@ -3341,8 +3341,12 @@ void take_along_axis_grad(const Tensor& arr,
arr_cast.dtype(),
arr_cast.place());
}
auto arr_grad_tmp =
put_along_axis<T>(zero_tensor, indices, out_grad_cast, axis);
auto arr_grad_tmp = put_along_axis<T>(zero_tensor,
indices,
out_grad_cast,
axis,
/*reduce*/ "add",
/*include_self*/ true);
set_output<T>(ConvertToOrig<T>(arr_grad_tmp, arr.dtype()), arr_grad);
}
}
Expand Down
30 changes: 30 additions & 0 deletions test/legacy_test/test_take_along_axis_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,36 @@ def init_data(self):
self.axis_type = "int64"


class TestTakeAlongAxisDuplicatedIndices(TestTakeAlongAxisOp):
def init_data(self):
self.dtype = np.float32
self.x_type = "float32"
self.x_shape = (5, 6, 7)
self.index_type = "int64"
self.axis = 2
dim_size = self.x_shape[self.axis]
self.index = (
np.asarray([-dim_size, -dim_size, dim_size - 1, dim_size - 1, 0])
.astype(self.index_type)
.reshape([5, 1, 1])
)
self.axis_type = "int64"

def test_check_output(self):
self.check_output(
check_cinn=self.check_cinn, check_pir=True, check_prim_pir=True
)

def test_check_grad(self):
self.check_grad(
['Input'],
'Result',
check_cinn=self.check_cinn,
check_pir=True,
check_prim_pir=True,
)


class TestTakeAlongAxisFP16Op(TestTakeAlongAxisOp):
def init_data(self):
self.dtype = np.float16
Expand Down

0 comments on commit c7d49ec

Please sign in to comment.