Add max_pool2d forward and backward implementation #958
base: master
Conversation
Summary of Changes
Hello @Blury233, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request integrates a new `max_pool2d` operator, a fundamental component in convolutional neural networks, into the system. The implementation includes highly optimized kernels for both the forward and backward passes, crucial for efficient deep learning computations. The changes also encompass comprehensive testing and benchmarking to validate the operator's accuracy and performance.
Highlights
- New Operator: MaxPool2d: Introduces the `max_pool2d` operator with both forward and backward implementations, enabling max pooling operations within the framework.
- Triton Kernels for Performance: The `max_pool2d` operator leverages optimized Triton kernels for its forward and backward passes, aiming for high performance on compatible hardware.
- Comprehensive Testing: Includes new accuracy tests for the `max_pool2d` operator, covering various configurations and ensuring correctness of both forward and backward computations. Performance benchmarks have also been added and reported (a generic accuracy-check sketch follows this list).
- ELU Backward Refinement: The `elu_backward` implementation has been refined to support different input types (original input vs. result of the forward pass), enhancing flexibility and correctness.
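As an illustration of how such accuracy tests are commonly structured, here is a minimal, generic sketch (not the PR's actual test code; `custom_max_pool2d` is a placeholder for the operator under test) that compares forward and backward results against `torch.nn.functional.max_pool2d`:

```python
import torch
import torch.nn.functional as F

def check_max_pool2d(custom_max_pool2d, shape=(2, 3, 32, 32), dtype=torch.float32):
    # Same data for both implementations, tracked separately for gradients.
    x_ref = torch.randn(shape, dtype=dtype, requires_grad=True)
    x_tst = x_ref.detach().clone().requires_grad_(True)

    out_ref = F.max_pool2d(x_ref, kernel_size=3, stride=2, padding=1)
    out_tst = custom_max_pool2d(x_tst, kernel_size=3, stride=2, padding=1)
    torch.testing.assert_close(out_tst, out_ref)  # forward accuracy

    grad = torch.randn_like(out_ref)
    out_ref.backward(grad)
    out_tst.backward(grad)
    torch.testing.assert_close(x_tst.grad, x_ref.grad)  # backward accuracy
```

Running `check_max_pool2d(F.max_pool2d)` trivially passes and is a quick sanity check of the harness itself.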
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either `/gemini <command>` or `@gemini-code-assist <command>`. Below is a summary of the supported commands.
| Feature | Command | Description |
|---|---|---|
| Code Review | `/gemini review` | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | `/gemini summary` | Provides a summary of the current pull request in its current state. |
| Comment | `@gemini-code-assist` | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | `/gemini help` | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a `.gemini/` folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
[1] Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.
Code Review
This pull request introduces an implementation for `max_pool2d` forward and backward passes, along with corresponding performance and accuracy tests. The implementation is well-structured, using Triton for the kernels and `torch.autograd.Function` for integration. The tests cover various configurations and seem comprehensive.
I've identified a few areas for improvement in `src/flag_gems/ops/max_pool2d.py`:
- A potential precision issue in the backward pass when using the `float64` data type.
- An opportunity to make the stride handling more robust.
- A small cleanup to remove unused parameters in the backward kernel.
The pull request also includes a refactoring of the `elu` backward pass, which is a good optimization but is not mentioned in the description. It would be good to update the description to reflect all changes.
Overall, this is a solid contribution. Addressing the suggested points will enhance the robustness and clarity of the new `max_pool2d` operator.
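For readers unfamiliar with the integration pattern mentioned above, the following is a minimal, runnable sketch of wrapping pooling kernels in a `torch.autograd.Function`. It is not the PR's code: the class name is illustrative, and plain PyTorch ops stand in for the Triton kernel launches so the example runs anywhere.

```python
import torch
import torch.nn.functional as F

class MaxPool2dFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, kernel_size, stride):
        # Stand-in for the Triton forward kernel: pooled values plus the flat
        # index (within each HxW plane) of the element that produced each max.
        output, indices = F.max_pool2d(
            input, kernel_size, stride, return_indices=True
        )
        ctx.save_for_backward(indices)
        ctx.in_shape = input.shape
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        n, c, h, w = ctx.in_shape
        # Stand-in for the Triton backward kernel: scatter each output
        # gradient back to the input position that produced the max.
        grad_input = torch.zeros(
            n, c, h * w, device=grad_output.device, dtype=grad_output.dtype
        )
        grad_input.scatter_add_(
            2, indices.reshape(n, c, -1), grad_output.reshape(n, c, -1)
        )
        # Only `input` receives a gradient; kernel_size and stride get None.
        return grad_input.reshape(n, c, h, w), None, None
```

A call such as `MaxPool2dFunction.apply(x, 2, 2)` then matches `F.max_pool2d(x, 2, 2)` in both passes, which is essentially what the PR's accuracy tests assert for the Triton implementation.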
```python
grad_input = torch.zeros(
    in_shape, device=grad_output.device, dtype=torch.float32
)
```
The `grad_input` tensor is hardcoded to `torch.float32` for accumulation, which is a good practice for `float16` and `bfloat16` to maintain precision. However, if `grad_output` is `torch.float64`, this will cause precision loss, as the `atomic_add` in the kernel will accumulate `float64` gradients into a `float32` tensor. To support `float64` correctly, the accumulation tensor `grad_input` should also be `torch.float64`.
```python
grad_input_dtype = torch.float32
if original_dtype == torch.float64:
    grad_input_dtype = torch.float64
grad_input = torch.zeros(
    in_shape, device=grad_output.device, dtype=grad_input_dtype
)
```
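To illustrate why the accumulation dtype matters here (a standalone illustration, not the kernel itself): a small `float64` contribution is rounded away when added into a `float32` buffer, but preserved in a `float64` one.

```python
import torch

grad = torch.tensor(1e-8, dtype=torch.float64)  # a small float64 gradient

buf32 = torch.ones((), dtype=torch.float32)
buf64 = torch.ones((), dtype=torch.float64)

buf32.add_(grad)  # 1.0 + 1e-8 rounds back to 1.0 at float32 precision
buf64.add_(grad)  # stays ~1.00000001 at float64 precision

print(buf32.item(), buf64.item())  # 1.0 vs ~1.00000001
```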
```python
def max_pool2d_backward_kernel(
    grad_output_ptr,
    indices_ptr,
    grad_input_ptr,
    # Input tensor strides
    in_stride_n,
    in_stride_c,
    in_stride_h,
    in_stride_w,
    # Shape info
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
):
```
The parameters `in_stride_h` and `in_stride_w` are unused in this kernel. Removing them would make the code cleaner. Remember to also update the call site in the `backward` method.
```python
def max_pool2d_backward_kernel(
    grad_output_ptr,
    indices_ptr,
    grad_input_ptr,
    # Input tensor strides
    in_stride_n,
    in_stride_c,
    # Shape info
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
):
```
```python
ctx.save_for_backward(indices)
ctx.in_shape = input.shape
ctx.in_strides = (in_c * in_h * in_w, in_h * in_w, in_w, 1)
```
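The review summary's note about making stride handling more robust presumably concerns this block, since the strides are recomputed as if the input were contiguous. A minimal sketch of one possible alternative (an assumption, not the reviewer's verbatim suggestion, and only helpful if the kernels honor arbitrary strides):

```python
ctx.save_for_backward(indices)
ctx.in_shape = input.shape
# Record the tensor's actual strides instead of recomputing contiguous ones;
# this equals (C*H*W, H*W, W, 1) for contiguous input but stays correct for
# non-contiguous layouts (e.g. channels_last or sliced tensors).
ctx.in_strides = input.stride()
```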
```python
max_pool2d_backward_kernel[grid](
    grad_output,
    indices,
    grad_input,
    in_strides[0],
    in_strides[1],
    in_strides[2],
    in_strides[3],
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
)
```
The arguments `in_strides[2]` and `in_strides[3]` correspond to `in_stride_h` and `in_stride_w`, which are unused in `max_pool2d_backward_kernel`. They should be removed from this call to match the updated kernel signature.
```python
max_pool2d_backward_kernel[grid](
    grad_output,
    indices,
    grad_input,
    in_strides[0],
    in_strides[1],
    in_c,
    in_h,
    in_w,
    out_h,
    out_w,
)
```
```python
else:
    return elu_backward_kernel_with_self(
        grad_output, alpha, scale, input_scale, self_or_result
    )
```
This file appears to have conflicts with another PR. Please update your code.
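For background on the `elu_backward` refinement mentioned in the review summary, here is a reference-style sketch, in plain PyTorch rather than Triton, of why the gradient can be computed from either the saved input or the forward result (the function name and argument order are illustrative, not the PR's actual kernel interface):

```python
import torch

def elu_backward_ref(grad_output, alpha, scale, input_scale, self_or_result, is_result):
    if is_result:
        # self_or_result is the forward output. On the negative branch,
        # result = alpha * scale * (exp(input_scale * x) - 1), so
        # alpha * scale * exp(input_scale * x) == result + alpha * scale.
        neg = grad_output * input_scale * (self_or_result + alpha * scale)
    else:
        # self_or_result is the original input saved from the forward pass.
        neg = grad_output * input_scale * alpha * scale * torch.exp(
            input_scale * self_or_result
        )
    pos = grad_output * scale
    return torch.where(self_or_result > 0, pos, neg)
```

Dispatching to separate kernels for the two cases, as the diff above suggests, avoids re-evaluating `exp` when the forward result is already available.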
PR Category
Operator
Type of Change
New Feature
Description
This PR adds:
- A `max_pool2d` operator with Triton forward and backward kernels.
- Accuracy tests and performance benchmarks for the new operator.
Issue
Progress
Performance
Accuracy Test




Performance Test