Add avg_pool2d forward and backward implementation #959
base: master
Conversation
Summary of Changes
Hello @Blury233, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly expands the operator library by introducing highly optimized `avg_pool2d` and `max_pool2d` implementations, complete with their respective forward and backward passes. These new pooling operations are built using Triton kernels for performance and are thoroughly validated with comprehensive accuracy and performance tests. Additionally, the backward pass for the ELU activation function has been improved to support different computation modes.
Highlights
- AvgPool2d Implementation: Added forward and backward passes for `avg_pool2d` using Triton kernels, including support for various parameters like `kernel_size`, `stride`, `padding`, `ceil_mode`, `count_include_pad`, and `divisor_override`.
- MaxPool2d Implementation: Introduced optimized forward and backward passes for `max_pool2d` with Triton kernels, supporting `kernel_size`, `stride`, `padding`, `dilation`, and `ceil_mode`.
- Comprehensive Testing: Integrated extensive accuracy and performance benchmarks for both `avg_pool2d` and `max_pool2d` to ensure correctness and efficiency across diverse configurations and input shapes.
- ELU Backward Refinement: Enhanced the `elu_backward` function to correctly handle the `is_result` parameter, improving flexibility in backward computations for the ELU activation function.
Code Review
This pull request introduces Triton implementations for `avg_pool2d` and `max_pool2d`, including both forward and backward passes, along with corresponding accuracy and performance tests. The PR description only mentions `avg_pool2d`, but the changes correctly include both operators.
My review focuses on potential performance improvements, code duplication, and robustness. I've identified an opportunity to optimize the `avg_pool2d` backward pass and suggested refactoring to reduce duplicated code in both the operator implementations and their benchmarks. I've also pointed out a minor issue with hardcoded strides.
Overall, this is a great contribution that extends the library's capabilities. Addressing the feedback will further improve the performance and maintainability of the new code.
else:
    # Re-compute count for the divisor when padding is not included.
    h_start = h_out * stride_h - padding_h
    w_start = w_out * stride_w - padding_w
    count = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
    for kh_count in range(0, kernel_h):
        for kw_count in range(0, kernel_w):
            h_in_for_count = h_start + kh_count * dilation_h
            w_in_for_count = w_start + kw_count * dilation_w
            is_valid = (
                (h_in_for_count >= 0)
                & (h_in_for_count < in_h)
                & (w_in_for_count >= 0)
                & (w_in_for_count < in_w)
            )
            count += is_valid.to(tl.int32)
    divisor = count.to(tl.float32)
The re-computation of `count` within the backward pass when `COUNT_INCLUDE_PAD` is `False` is inefficient. This involves nested loops that run for every element in the input gradient block, which can significantly degrade performance.
A more efficient approach would be to compute the `count` tensor during the forward pass and save it for the backward pass using `ctx.save_for_backward()`. The forward kernel already computes this information in `count_acc`, so it could be modified to return this value. This would eliminate the need for the expensive nested loops in the backward kernel.
class AvgPool2dBenchmark(GenericBenchmark):
    def get_input_iter(self, cur_dtype) -> Generator:
        shapes_4d = [
            (4, 3, 224, 224),  # Typical input image size
            (16, 64, 56, 56),  # Early ResNet layer output
            (32, 128, 28, 28),  # Mid ResNet layer output
            (64, 256, 14, 14),  # Later ResNet layer output
            (128, 512, 7, 7),  # Final ResNet layer output
        ]

        for shape in shapes_4d:
            yield from self.input_fn(shape, cur_dtype, self.device)
The `get_input_iter` method is identical in `AvgPool2dBenchmark` and `MaxPool2dBenchmark` (lines 392-403). You can refactor this to reduce code duplication. For instance, you could create a common base class for pooling benchmarks that contains this method.
Example:
class Pool2dBenchmark(GenericBenchmark):
    def get_input_iter(self, cur_dtype) -> Generator:
        shapes_4d = [
            (4, 3, 224, 224),  # Typical input image size
            (16, 64, 56, 56),  # Early ResNet layer output
            (32, 128, 28, 28),  # Mid ResNet layer output
            (64, 256, 14, 14),  # Later ResNet layer output
            (128, 512, 7, 7),  # Final ResNet layer output
        ]
        for shape in shapes_4d:
            yield from self.input_fn(shape, cur_dtype, self.device)


class AvgPool2dBenchmark(Pool2dBenchmark):
    pass


class MaxPool2dBenchmark(Pool2dBenchmark):
    pass
Then you would initialize `AvgPool2dBenchmark` and `MaxPool2dBenchmark` in your tests as before, but they would inherit the common `get_input_iter`.
def pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    effective_kernel_size = (kernel_size - 1) * dilation + 1
    numerator = in_size + 2 * padding - effective_kernel_size
    if ceil_mode:
        return (numerator + stride - 1) // stride + 1
    else:
        return numerator // stride + 1
We accept it for now. Maybe sometime we could refactor the code to extract some utility functions shared by convolution, convolution_transpose, and pooling.
There is some risk that the output size differs from that of the ATen counterpart.
ceil_mode has two effects:
- It allows the window's right edge to extend beyond the padded input;
- but it does not allow a window to start inside the right padding region.

So follow ATen's implementation to stay consistent; otherwise, in some cases the output size is larger than ATen's by 1.
template<typename T>
inline T pooling_output_shape_pad_lr(
    T inputSize, T kernelSize, T pad_l, T pad_r, T stride, T dilation,
    bool ceil_mode) {
  T outputSize = div_rtn<T>(
      inputSize + pad_l + pad_r - dilation * (kernelSize - 1) - 1 +
      (ceil_mode ? stride - 1 : 0), stride) + 1;
  if (ceil_mode) {
    // ensure that the last pooling starts inside the image
    // needed to avoid problems in ceil mode
    if ((outputSize - 1) * stride >= inputSize + pad_l) {
      --outputSize;
    }
  }
  return outputSize;
}
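For comparison, a direct Python port of the ATen logic above, specialized to the symmetric padding that the PR's helper takes (pad_l == pad_r == padding), would look roughly like this; only the final ceil_mode clamp differs from the current pool2d_output_size:

def pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    numerator = in_size + 2 * padding - dilation * (kernel_size - 1) - 1
    if ceil_mode:
        numerator += stride - 1
    out_size = numerator // stride + 1  # // matches ATen's div_rtn (floor division)
    if ceil_mode and (out_size - 1) * stride >= in_size + padding:
        # ensure the last pooling window starts inside the input (or left padding)
        out_size -= 1
    return out_size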
src/flag_gems/ops/max_pool2d.py
Outdated
def max_pool2d_output_size(
    in_size: int,
    kernel_size: int,
    stride: int,
    padding: int,
    dilation: int,
    ceil_mode: bool = False,
) -> int:
    effective_kernel_size = (kernel_size - 1) * dilation + 1
    numerator = in_size + 2 * padding - effective_kernel_size
    if ceil_mode:
        return (numerator + stride - 1) // stride + 1
    else:
        return numerator // stride + 1
I agree
src/flag_gems/ops/max_pool2d.py
Outdated
ctx.save_for_backward(indices)
ctx.in_shape = input.shape
ctx.in_strides = (in_c * in_h * in_w, in_h * in_w, in_w, 1)
The `in_strides` saved to the context are hardcoded assuming a contiguous tensor. While `input` is made contiguous at the start of the function, it's better practice to use `input.stride()` to get the actual strides. This makes the code more robust if the contiguity assumption changes in the future.
Suggested change:
- ctx.in_strides = (in_c * in_h * in_w, in_h * in_w, in_w, 1)
+ ctx.in_strides = input.stride()
I agree.
src/flag_gems/ops/max_pool2d.py
Outdated
    * triton.cdiv(out_w, meta["BLOCK_W"]),
)

max_pool2d_forward_kernel_optimized[grid](
Add a device guard to ensure that the kernel is launched on the correct device.
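A minimal sketch of the guard, assuming a CUDA-like backend (FlagGems may route this through its own device abstraction; torch.cuda.device is used here purely for illustration):

with torch.cuda.device(input.device):
    # Launch on the device that owns `input`, even if a different device is current.
    max_pool2d_forward_kernel_optimized[grid](
        # ... same arguments as in the current call ...
    )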
src/flag_gems/ops/max_pool2d.py
Outdated
grid = (grad_output.numel(),)
if grad_output.numel() > 0:
    max_pool2d_backward_kernel[grid](
In this case, a CTA computes a single element of the output. This may be too fine-grained.
src/flag_gems/ops/max_pool2d.py
Outdated
if max_idx_flat != -1:
    grad_input_offset = max_idx_flat
    # Atomically add the gradient to the correct location in grad_input
    tl.atomic_add(
Maybe consider rewriting it as a max-pool-like kernel that partitions the work with respect to grad_input, avoiding atomic adds to accumulate the results.
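A hedged sketch of such a gather-style kernel follows. It assumes the saved indices hold the flat h * W + w position inside each (n, c) plane, as ATen's max_pool2d_with_indices does; if the PR stores a fully flattened index instead, the comparison against h_in * in_w + w_in would need adjusting.

import triton
import triton.language as tl


@triton.jit
def max_pool2d_backward_gather_kernel(
    grad_output_ptr, indices_ptr, grad_input_ptr,
    n_elements, in_h, in_w, out_h, out_w,
    stride_h, stride_w, padding_h, padding_w,
    kernel_h: tl.constexpr, kernel_w: tl.constexpr,
    dilation_h: tl.constexpr, dilation_w: tl.constexpr,
    BLOCK: tl.constexpr,
):
    # Each program owns a block of grad_input elements, so every element is
    # written by exactly one program and no atomics are needed.
    pid = tl.program_id(0)
    offsets = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements

    # Decode (n*c, h_in, w_in) from the flat NCHW offset.
    w_in = offsets % in_w
    h_in = (offsets // in_w) % in_h
    nc = offsets // (in_h * in_w)

    acc = tl.zeros((BLOCK,), dtype=tl.float32)
    # Enumerate every output window that could contain (h_in, w_in).
    for kh in range(0, kernel_h):
        for kw in range(0, kernel_w):
            h_off = h_in + padding_h - kh * dilation_h
            w_off = w_in + padding_w - kw * dilation_w
            h_out = h_off // stride_h
            w_out = w_off // stride_w
            valid = (
                (h_off % stride_h == 0) & (w_off % stride_w == 0)
                & (h_out >= 0) & (h_out < out_h)
                & (w_out >= 0) & (w_out < out_w)
            )
            out_offset = (nc * out_h + h_out) * out_w + w_out
            idx = tl.load(indices_ptr + out_offset, mask=mask & valid, other=-1)
            hit = valid & (idx == h_in * in_w + w_in)
            grad = tl.load(grad_output_ptr + out_offset, mask=mask & hit, other=0.0)
            acc += grad.to(tl.float32)

    tl.store(grad_input_ptr + offsets, acc.to(grad_input_ptr.dtype.element_ty), mask=mask)

The launch would then be partitioned over grad_input, e.g. grid = (triton.cdiv(grad_input.numel(), BLOCK),), so no two programs ever write the same element.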
src/flag_gems/ops/max_pool2d.py
Outdated
)


class MaxPool2d(torch.autograd.Function):
Avoid using `autograd.Function`, since it is not the encouraged approach now. Overwrite `max_pool2d_with_indices_backward.grad_input` or `max_pool2d_with_indices_backward` instead.
5ec214c to 715e7c4
)


class AvgPool2d(torch.autograd.Function):
I suggest not subclassing `torch.autograd.Function`, and instead adding implementations for the forward and backward operators separately:
- `avg_pool2d` or `avg_pool2d.out`
- `avg_pool2d_backward` and `avg_pool2d_backward.grad_input`

To do that, we need to check their schemas in pytorch/aten/src/ATen/native/native_functions.yaml and their signatures in the implementation files, to make sure our implementation uses the same signatures.
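For reference, the corresponding standalone entry points would look roughly like the sketch below; the parameter lists are paraphrased from the avg_pool2d / avg_pool2d_backward schemas in recent versions of native_functions.yaml and should be double-checked against the exact PyTorch version in use:

def avg_pool2d(
    self,                   # the input tensor; ATen names it `self` in the schema
    kernel_size,            # int[2]
    stride=None,            # int[2]; an empty value means "use kernel_size"
    padding=0,              # int[2]
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,  # int? in the schema
):
    ...


def avg_pool2d_backward(
    grad_output,
    self,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
):
    ...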
    divisor_override,
):
    logger.debug("GEMS AVG_POOL2D FORWARD")
    input = input.contiguous()
Add argument checking (a sketch follows below):
- stride should not be zero
- pad should not be negative
- pad should be at most half of the window size
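A minimal sketch of those checks, assuming the per-dimension values have already been unpacked into scalars (the variable names mirror the kernel parameters in the diff but are otherwise illustrative):

if stride_h == 0 or stride_w == 0:
    raise ValueError("stride must be non-zero")
if padding_h < 0 or padding_w < 0:
    raise ValueError("padding must be non-negative")
if padding_h > kernel_h // 2 or padding_w > kernel_w // 2:
    raise ValueError("padding should be at most half of the kernel size")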
Please keep in mind that ATen's implementation tries to preserve and propagate the channels_last layout when possible, but this implementation may introduce extra transpositions:

const auto memory_format = input_.suggest_memory_format();
Tensor input = input_.contiguous(memory_format);

Maybe we can leave this for further optimizations.
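A rough Python approximation of that behaviour (it only checks channels_last contiguity, which is weaker than ATen's suggest_memory_format(), so treat it as a sketch):

import torch

memory_format = (
    torch.channels_last
    if input.is_contiguous(memory_format=torch.channels_last)
    else torch.contiguous_format
)
# Keep NHWC inputs in NHWC instead of forcing an NCHW copy.
input = input.contiguous(memory_format=memory_format)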
    dilation_h,
    dilation_w,
    COUNT_INCLUDE_PAD=count_include_pad,
    divisor_override=divisor_override or 0,
I advise checking that the divisor is not zero when it is not None. If it is None, just pass it through to the jit function; do not use 0 as a magic number.
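A small sketch of that validation; whether None can then be passed straight into the @triton.jit function depends on the Triton version (None arguments are treated as compile-time constants in recent releases), so that part is an assumption to verify:

if divisor_override is not None and divisor_override == 0:
    raise ValueError("divisor_override must be non-zero")

After the check, divisor_override (possibly None) can be forwarded to the kernel as-is instead of substituting 0.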
w_out_mask = w_valid_map & (w_out < out_w)
out_mask = h_out_mask & w_out_mask

if divisor_override > 0:
The divisor may also be negative, so checking for > 0 alone is not sufficient.
    dilation_w: tl.constexpr,
    # AvgPool specific parameters
    COUNT_INCLUDE_PAD: tl.constexpr,
    divisor_override: tl.constexpr,
`divisor_override` should not be a constexpr, since it may be a float; keeping it a runtime argument should not be critical to performance.
            count += is_valid.to(tl.int32)
    divisor = count.to(tl.float32)

    divisor = tl.where(divisor == 0, 1.0, divisor)
In what case may divisor be zero?
PR Category
Operator
Type of Change
New Feature
Description
This PR adds:
Issue
Associated with Feature #894
Progress
Performance
Accuracy Test




Performance Test