Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions kernels/portable/cpu/op_pixel_shuffle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,12 @@ Tensor& pixel_shuffle_out(
ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);
Tensor::SizesType expected_out_size[kTensorDimensionLimit];
size_t expected_out_dim = 0;
get_pixel_shuffle_out_target_size(
in, upscale_factor, expected_out_size, &expected_out_dim);
ET_KERNEL_CHECK(
ctx,
get_pixel_shuffle_out_target_size(
in, upscale_factor, expected_out_size, &expected_out_dim),
InvalidArgument,
out);

// Make sure the output tensor is the right size.
ET_KERNEL_CHECK(
Expand Down
8 changes: 7 additions & 1 deletion kernels/portable/cpu/util/copy_ops_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -346,11 +346,15 @@ bool check_pixel_unshuffle_args(
return true;
}

void get_pixel_shuffle_out_target_size(
bool get_pixel_shuffle_out_target_size(
const Tensor& in,
int64_t upscale_factor,
executorch::aten::SizesType* out_sizes,
size_t* out_ndim) {
// Prevent signed integer overflow when computing upscale_factor ^ 2.
ET_CHECK_OR_RETURN_FALSE(
upscale_factor < 32768, "Upscale factor must be less than 32768.");

*out_ndim = in.dim();
const executorch::aten::SizesType casted_upscale_factor = upscale_factor;

Expand All @@ -366,6 +370,8 @@ void get_pixel_shuffle_out_target_size(
out_sizes[i] = in.size(i) * casted_upscale_factor;
i++;
out_sizes[i] = in.size(i) * casted_upscale_factor;

return true;
}

void get_pixel_unshuffle_out_target_size(
Expand Down
2 changes: 1 addition & 1 deletion kernels/portable/cpu/util/copy_ops_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ bool check_pixel_shuffle_args(
int64_t upscale_factor,
Tensor& out);

void get_pixel_shuffle_out_target_size(
bool get_pixel_shuffle_out_target_size(
const Tensor& in,
int64_t upscale_factor,
executorch::aten::SizesType* out_sizes,
Expand Down
Loading