Conversation

Blury233
Contributor

PR Category

Operator

Type of Change

New Feature

Description

This PR adds:

  • avg_pool2d forward
  • avg_pool2d backward

Issue

Associated with Feature #894

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a unit test (UT).

Performance

Accuracy Test
[screenshot: acc]
Performance Test
[screenshots: perf1, perf2, perf3]

Contributor

@gemini-code-assist (bot) left a comment

Summary of Changes

Hello @Blury233, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the operator library by introducing highly optimized avg_pool2d and max_pool2d implementations, complete with their respective forward and backward passes. These new pooling operations are built using Triton kernels for performance and are thoroughly validated with comprehensive accuracy and performance tests. Additionally, the backward pass for the ELU activation function has been improved to support different computation modes.

Highlights

  • AvgPool2d Implementation: Added forward and backward passes for avg_pool2d using Triton kernels, including support for various parameters like kernel_size, stride, padding, ceil_mode, count_include_pad, and divisor_override.
  • MaxPool2d Implementation: Introduced optimized forward and backward passes for max_pool2d with Triton kernels, supporting kernel_size, stride, padding, dilation, and ceil_mode.
  • Comprehensive Testing: Integrated extensive accuracy and performance benchmarks for both avg_pool2d and max_pool2d to ensure correctness and efficiency across diverse configurations and input shapes.
  • ELU Backward Refinement: Enhanced the elu_backward function to correctly handle the is_result parameter, improving flexibility in backward computations for the ELU activation function.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces Triton implementations for avg_pool2d and max_pool2d, including both forward and backward passes, along with corresponding accuracy and performance tests. The PR description only mentions avg_pool2d, but the changes correctly include both operators.

My review focuses on potential performance improvements, code duplication, and robustness. I've identified an opportunity to optimize the avg_pool2d backward pass and suggested refactoring to reduce duplicated code in both the operator implementations and their benchmarks. I've also pointed out a minor issue with hardcoded strides.

Overall, this is a great contribution that extends the library's capabilities. Addressing the feedback will further improve the performance and maintainability of the new code.

Comment on lines +215 to +231
else:
    # Re-compute count for the divisor when padding is not included.
    h_start = h_out * stride_h - padding_h
    w_start = w_out * stride_w - padding_w
    count = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
    for kh_count in range(0, kernel_h):
        for kw_count in range(0, kernel_w):
            h_in_for_count = h_start + kh_count * dilation_h
            w_in_for_count = w_start + kw_count * dilation_w
            is_valid = (
                (h_in_for_count >= 0)
                & (h_in_for_count < in_h)
                & (w_in_for_count >= 0)
                & (w_in_for_count < in_w)
            )
            count += is_valid.to(tl.int32)
    divisor = count.to(tl.float32)
Contributor

high

The re-computation of count within the backward pass when COUNT_INCLUDE_PAD is False is inefficient. This involves nested loops that run for every element in the input gradient block, which can significantly degrade performance.

A more efficient approach would be to compute the count tensor during the forward pass and save it for the backward pass using ctx.save_for_backward(). The forward kernel already computes this information in count_acc, so it could be modified to return this value. This would eliminate the need for the expensive nested loops in the backward kernel.
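The suggestion can be illustrated in miniature. Below is a 1-D pure-Python sketch (all helper names are hypothetical, not from this PR): the forward pass computes the per-window count of non-padding elements once and saves it (analogous to ctx.save_for_backward), and the backward pass reuses it instead of re-running the nested loops.

```python
def pool_counts_1d(in_size, kernel, stride, padding, out_size):
    """Number of non-padding input elements covered by each output window."""
    counts = []
    for o in range(out_size):
        start = o * stride - padding
        end = start + kernel
        # Clip the window to the valid input range [0, in_size).
        counts.append(max(0, min(end, in_size) - max(start, 0)))
    return counts

def forward(in_size, kernel, stride, padding, out_size):
    counts = pool_counts_1d(in_size, kernel, stride, padding, out_size)
    # Analogous to ctx.save_for_backward(count): keep the counts around.
    return counts

def backward(saved_counts):
    # The backward pass reads the saved counts directly; no nested
    # re-computation loops are needed here.
    return [1.0 / c if c > 0 else 0.0 for c in saved_counts]
```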

Comment on lines +316 to +295
class AvgPool2dBenchmark(GenericBenchmark):
    def get_input_iter(self, cur_dtype) -> Generator:
        shapes_4d = [
            (4, 3, 224, 224),  # Typical input image size
            (16, 64, 56, 56),  # Early ResNet layer output
            (32, 128, 28, 28),  # Mid ResNet layer output
            (64, 256, 14, 14),  # Later ResNet layer output
            (128, 512, 7, 7),  # Final ResNet layer output
        ]

        for shape in shapes_4d:
            yield from self.input_fn(shape, cur_dtype, self.device)
Contributor

medium

The get_input_iter method is identical in AvgPool2dBenchmark and MaxPool2dBenchmark (lines 392-403). You can refactor this to reduce code duplication. For instance, you could create a common base class for pooling benchmarks that contains this method.

Example:

class Pool2dBenchmark(GenericBenchmark):
    def get_input_iter(self, cur_dtype) -> Generator:
        shapes_4d = [
            (4, 3, 224, 224),  # Typical input image size
            (16, 64, 56, 56),  # Early ResNet layer output
            (32, 128, 28, 28),  # Mid ResNet layer output
            (64, 256, 14, 14),  # Later ResNet layer output
            (128, 512, 7, 7),  # Final ResNet layer output
        ]

        for shape in shapes_4d:
            yield from self.input_fn(shape, cur_dtype, self.device)

class AvgPool2dBenchmark(Pool2dBenchmark):
    pass

class MaxPool2dBenchmark(Pool2dBenchmark):
    pass

Then you would initialize AvgPool2dBenchmark and MaxPool2dBenchmark in your tests as before, but they would inherit the common get_input_iter.

Comment on lines +12 to +25
def pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    effective_kernel_size = (kernel_size - 1) * dilation + 1
    numerator = in_size + 2 * padding - effective_kernel_size
    if ceil_mode:
        return (numerator + stride - 1) // stride + 1
    else:
        return numerator // stride + 1
Contributor

medium

The helper function pool2d_output_size is also defined in src/flag_gems/ops/max_pool2d.py. To avoid code duplication and improve maintainability, consider moving this function to a shared utility module (e.g., in flag_gems/utils/) and importing it where needed.
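A sketch of that refactor (the module path flag_gems/utils/pooling.py is hypothetical): the helper is defined once in a shared module and both operator files import it instead of redefining it.

```python
# --- flag_gems/utils/pooling.py (hypothetical shared location) ---
def pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    """Single shared output-size helper, identical to the duplicated ones."""
    effective_kernel_size = (kernel_size - 1) * dilation + 1
    numerator = in_size + 2 * padding - effective_kernel_size
    if ceil_mode:
        return (numerator + stride - 1) // stride + 1
    return numerator // stride + 1

# --- in avg_pool2d.py and max_pool2d.py: import instead of redefining ---
# from flag_gems.utils.pooling import pool2d_output_size
```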

Collaborator

We accept it for now. Maybe sometime we will refactor the code to extract some utility functions for convolution, convolution_transpose, and pooling.

Collaborator

There may be some risk that the output size differs from the result of the ATen counterpart.

ceil_mode has 2 effects:

  1. It allows the window's right edge to extend beyond the padded input;
  2. but it does not allow the window to start in the right padding region.

So follow ATen's implementation to stay consistent; otherwise, in some cases the output size is larger than ATen's by 1.

template<typename T>
inline T pooling_output_shape_pad_lr(
        T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation,
        bool ceil_mode) {
    T outputSize = div_rtn<T>(
        inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 +
        (ceil_mode ? stride - 1 : 0), stride) + 1;
    if (ceil_mode) {
        // ensure that the last pooling starts inside the image
        // needed to avoid problems in ceil mode
        if ((outputSize - 1) * stride >= inputSize + pad_l) {
          --outputSize;
        }
    }
    return outputSize;
}
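For reference, a Python transcription of the ATen helper above (a sketch; div_rtn reduces to // for the non-negative sizes and positive strides used here). The final clamp is exactly the step the naive ceil formula in this PR is missing.

```python
def aten_pool_output_size(in_size, kernel, pad_l, pad_r, stride, dilation,
                          ceil_mode):
    # Same pre-clamp value as the PR's formula when pad_l == pad_r.
    out = (in_size + pad_l + pad_r - dilation * (kernel - 1) - 1
           + (stride - 1 if ceil_mode else 0)) // stride + 1
    if ceil_mode and (out - 1) * stride >= in_size + pad_l:
        # Ensure that the last pooling window starts inside the image.
        out -= 1
    return out

# Divergence example: in_size=3, kernel=2, stride=2, padding=1, ceil_mode=True.
# The PR's formula yields 3; ATen clamps to 2 because the third window would
# start at index 3, i.e. entirely inside the right padding.
```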

Comment on lines 13 to 26
def max_pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    effective_kernel_size = (kernel_size - 1) * dilation + 1
    numerator = in_size + 2 * padding - effective_kernel_size
    if ceil_mode:
        return (numerator + stride - 1) // stride + 1
    else:
        return numerator // stride + 1
Contributor

medium

This function max_pool2d_output_size is a duplicate of pool2d_output_size in src/flag_gems/ops/avg_pool2d.py. To avoid code duplication, this logic should be extracted into a single, shared utility function. I'd also recommend using a generic name like pool2d_output_size for it.

Collaborator

I agree


ctx.save_for_backward(indices)
ctx.in_shape = input.shape
ctx.in_strides = (in_c * in_h * in_w, in_h * in_w, in_w, 1)
Contributor

medium

The in_strides saved to the context are hardcoded assuming a contiguous tensor. While input is made contiguous at the start of the function, it's a better practice to use input.stride() to get the actual strides. This makes the code more robust if the contiguity assumption changes in the future.

Suggested change
- ctx.in_strides = (in_c * in_h * in_w, in_h * in_w, in_w, 1)
+ ctx.in_strides = input.stride()

Collaborator

I agree.

* triton.cdiv(out_w, meta["BLOCK_W"]),
)

max_pool2d_forward_kernel_optimized[grid](
Collaborator

Add a device guard to ensure that the kernel is launched on the correct device.


grid = (grad_output.numel(),)
if grad_output.numel() > 0:
    max_pool2d_backward_kernel[grid](
Collaborator

In this case, a CTA computes a single element of the output. That may be too small a unit of work.

if max_idx_flat != -1:
    grad_input_offset = max_idx_flat
    # Atomically add the gradient to the correct location in grad_input
    tl.atomic_add(
Collaborator

Maybe consider rewriting it with a max-pool-like kernel that partitions tasks w.r.t. grad_input, to avoid accumulating results with atomic adds.
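A 1-D pure-Python sketch of that gather-style partitioning (hypothetical helper names, illustrative only): each grad_input element scans just the output windows that can contain it and accumulates locally, so the result is written with a plain store rather than tl.atomic_add.

```python
def maxpool1d_forward(x, kernel, stride):
    """Reference forward: returns pooled values and argmax indices."""
    out, idx = [], []
    for o in range((len(x) - kernel) // stride + 1):
        window = range(o * stride, o * stride + kernel)
        best = max(window, key=lambda i: x[i])
        out.append(x[best])
        idx.append(best)
    return out, idx

def maxpool1d_backward_gather(grad_out, idx, in_len, kernel, stride):
    grad_in = [0.0] * in_len
    for i in range(in_len):  # one "CTA" per grad_input element
        # Output positions whose window [o*stride, o*stride+kernel) covers i.
        o_lo = max(0, (i - kernel) // stride + 1)
        o_hi = min(len(grad_out) - 1, i // stride)
        acc = 0.0
        for o in range(o_lo, o_hi + 1):
            if idx[o] == i:      # i was the argmax of window o
                acc += grad_out[o]
        grad_in[i] = acc         # plain store, no atomic_add needed
    return grad_in
```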

)


class MaxPool2d(torch.autograd.Function):
Collaborator

@iclementine Sep 12, 2025

Avoid using autograd.Function since it is no longer the encouraged approach.

Overwrite max_pool2d_with_indices_backward.grad_input or max_pool2d_with_indices_backward instead.

@Blury233 changed the title: Add max_pool2d forward and backward implementation → Feat/add avg pool2d → Add avg_pool2d forward and backward implementation (Sep 12, 2025)
)


class AvgPool2d(torch.autograd.Function):
Collaborator

I suggest not subclassing torch.autograd.Function; instead, add implementations for the forward and backward operators separately:

  • avg_pool2d or avg_pool2d.out
  • avg_pool2d_backward and avg_pool2d_backward.grad_input

To do that, we need to check their schema in pytorch/aten/src/ATen/native/native_functions.yaml and their signature in the implementation files to be sure that our implementation uses the same signature.

divisor_override,
):
logger.debug("GEMS AVG_POOL2D FORWARD")
input = input.contiguous()
Collaborator

Add argument checking that

  • stride should not be zero
  • pad should not be negative
  • pad is less than half of the window size
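A minimal sketch of those checks (check_pool2d_args is a hypothetical name; the real operator would run this validation before launching any kernel):

```python
def check_pool2d_args(kernel_size, stride, padding):
    """Validate (kh, kw), (sh, sw), (ph, pw) tuples per the review comment."""
    for s in stride:
        if s == 0:
            raise ValueError("stride should not be zero")
    for p in padding:
        if p < 0:
            raise ValueError("padding should not be negative")
    for p, k in zip(padding, kernel_size):
        if p > k // 2:
            raise ValueError(
                "padding should be at most half of the kernel size, "
                f"but got pad={p} for kernel_size={k}"
            )
```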

Collaborator

@iclementine Sep 19, 2025

Please keep in mind that ATen's implementation tries to preserve and propagate the channels_last layout when possible, but this implementation may introduce extra transposition.

  const auto memory_format = input_.suggest_memory_format();

  Tensor input = input_.contiguous(memory_format);

Maybe we can leave it for further optimizations.

dilation_h,
dilation_w,
COUNT_INCLUDE_PAD=count_include_pad,
divisor_override=divisor_override or 0,
Collaborator

I advise checking that the divisor is not zero when it is not None.
If it is None, just pass None to the JIT function; do not use 0 as a magic number.

w_out_mask = w_valid_map & (w_out < out_w)
out_mask = h_out_mask & w_out_mask

if divisor_override > 0:
Collaborator

divisor may also be negative.

dilation_w: tl.constexpr,
# AvgPool specific parameters
COUNT_INCLUDE_PAD: tl.constexpr,
divisor_override: tl.constexpr,
Collaborator

divisor_override should not be a constexpr, since it may be a float; treating it as a runtime value is unlikely to be critical to performance.

count += is_valid.to(tl.int32)
divisor = count.to(tl.float32)

divisor = tl.where(divisor == 0, 1.0, divisor)
Collaborator

In what case may divisor be zero?
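One concrete answer, worked in 1-D (pure Python, illustrative only): with count_include_pad=False plus the naive ceil_mode output-size formula used in this PR, the last window can fall entirely inside the padding, making the non-pad count — and hence the divisor — zero. With ATen's clamp that window would not exist.

```python
def window_count(o, in_size, kernel, stride, padding):
    """Non-padding elements covered by output window o."""
    start = o * stride - padding
    return max(0, min(start + kernel, in_size) - max(start, 0))

in_size, kernel, stride, padding = 3, 2, 2, 1
numerator = in_size + 2 * padding - kernel          # = 3
out_naive = (numerator + stride - 1) // stride + 1  # ceil formula -> 3
counts = [window_count(o, in_size, kernel, stride, padding)
          for o in range(out_naive)]
# counts == [1, 2, 0]: the third window covers indices 3..4, all padding.
```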
