Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
aquagull committed Dec 21, 2024
1 parent 2b25a77 commit c05e065
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 3 deletions.
2 changes: 1 addition & 1 deletion paddle/phi/kernels/funcs/broadcast_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -806,7 +806,7 @@ void BroadcastKernel(const KPDevice &ctx,
}
ctx.template Alloc<OutT>((*outs)[i]);
}
if ((*outs).size() == 1 && (*outs)[0]->numel() == 0) {
if ((*outs)[0]->numel() == 0) {
return;
}
int max_rank = 0;
Expand Down
5 changes: 3 additions & 2 deletions paddle/phi/kernels/funcs/common_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ inline void GetBroadcastDimsArrays(const DDim &x_dims,
}
for (int i = 0; i < max_dim; ++i) {
PADDLE_ENFORCE_EQ(
x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
y_dims_array[i] <= 1,
x_dims_array[i] == y_dims_array[i] ||
x_dims_array[i] <= 1 && x_dims_array[i] != 0 ||
y_dims_array[i] <= 1 && y_dims_array[i] != 0,
true,
common::errors::InvalidArgument(
"Broadcast dimension mismatch. Operands could "
Expand Down
17 changes: 17 additions & 0 deletions test/legacy_test/test_broadcast_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,23 @@ def test_error(self):
ValueError, paddle.broadcast_shape, [2, 1, 3], [3, 3, 1]
)

def test_zero_size_dim(self):
test_cases = [
([0], [], [0]),
([1], [0], [0]),
([2, -1], [0], [2, 0]),
([0, 3], [3], [0, 3]),
([0, 1, 3], [0, 1, 0, 3], [0, 0, 0, 3]),
([0, 1, 3], [0, 1, 1, 5, 3], [0, 1, 0, 5, 3]),
]

for shape1, shape2, expected in test_cases:
result = paddle.broadcast_shape(shape1, shape2)
self.assertEqual(result, expected)

def test_zero_size_error(self):
self.assertRaises(ValueError, paddle.broadcast_shape, [0], [0, 2])


if __name__ == "__main__":
unittest.main()

0 comments on commit c05e065

Please sign in to comment.