Add std #999

base: master
Conversation
Summary of Changes: Hello @botbigeyes, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed: this pull request introduces a new standard deviation (`std`) operator.
Code Review

This pull request introduces a new `std` operator. The implementation is functional but has several areas for improvement. The new file `src/flag_gems/ops/std.py` contains a significant amount of code duplicated from the existing `var_mean` operator, which is a concern for maintainability. Additionally, the kernel for calculating variance along a specific dimension uses a numerically unstable algorithm, and the operator currently lacks support for automatic differentiation (autograd), which might limit its use in training scenarios. I've provided detailed comments and suggestions to address these issues, including a proposal to refactor `std.py` to reuse the `var_mean` op, a fix for the numerical stability of the kernel, and a note on adding autograd support. A minor efficiency improvement for the new test case is also suggested.
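For readers following along: the instability referred to typically comes from the single-pass formula Var(x) = E[x²] − E[x]², which suffers catastrophic cancellation when the mean is large relative to the spread. Below is a minimal PyTorch sketch of the numerically stable two-pass alternative; `two_pass_var` is a hypothetical helper for illustration, not the kernel in this PR.

```python
import torch

def two_pass_var(x: torch.Tensor, dim: int, correction: int = 1) -> torch.Tensor:
    # Pass 1: reduce to the mean along `dim`.
    mean = x.mean(dim=dim, keepdim=True)
    # Pass 2: sum squared deviations from that mean. Unlike
    # E[x^2] - E[x]^2, this does not cancel catastrophically.
    sq_dev = (x - mean).square().sum(dim=dim)
    return sq_dev / (x.shape[dim] - correction)
```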
/gemini review
Code Review

This pull request adds support for the `std` operator. The implementation, however, has a critical issue in the forward pass (the `std` function), which returns a tuple instead of a single tensor, contrary to `torch.std`'s behavior. This will cause the provided tests and any dependent code to fail. The backward pass implementation (`std_backward`) relies on this incorrect signature, and its registration uses a non-standard operator name, indicating potential design issues that need to be addressed.
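For reference, `torch.std` itself returns a single tensor; the tuple-returning variant is the separate `torch.std_mean` op:

```python
import torch

x = torch.randn(4, 8)
out = torch.std(x, dim=1)             # a single tensor, not a tuple
assert isinstance(out, torch.Tensor)

std, mean = torch.std_mean(x, dim=1)  # the tuple-returning variant
```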
/gemini review
Code Review

This pull request adds support for the `torch.std` operation by composing the existing `var_mean` and `sqrt` operators for the forward pass, and providing a custom Triton kernel for the backward pass. The changes include the operator implementation, registration, and corresponding benchmark and accuracy tests. My review identifies a critical issue in the backward pass implementation (`std_backward`) where the logic for handling partial reductions with `keepdim=False` is flawed. This can lead to incorrect gradient calculations. I've provided a detailed explanation and a partial fix.
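For context, the analytical gradient is ∂std/∂xᵢ = (xᵢ − mean) / ((N − correction) · std), and with `keepdim=False` the upstream gradient and the saved statistics must be unsqueezed back to broadcastable shapes before applying it. The sketch below is a hedged eager-mode reference for that bookkeeping; all names are illustrative and it is not the PR's actual Triton kernel.

```python
import torch

def std_backward_reference(grad_out, x, dims, correction, keepdim, mean, std):
    # Re-insert the reduced dims so grad_out/mean/std broadcast against x.
    if not keepdim:
        for d in sorted(dims):
            grad_out = grad_out.unsqueeze(d)
            mean = mean.unsqueeze(d)
            std = std.unsqueeze(d)
    n = 1
    for d in dims:
        n *= x.shape[d]
    # d(std)/dx_i = (x_i - mean) / ((N - correction) * std)
    return grad_out * (x - mean) / ((n - correction) * std)
```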
/gemini review
Code Review

This pull request adds support for the `torch.std` operator. The changes include the operator implementation, registration, performance benchmarks, and accuracy tests. The overall approach is sound, but there is a critical bug in the `std` implementation. It incorrectly passes the `unbiased` parameter to the `var_mean` function, which expects a `correction` parameter, leading to a `TypeError`. I have provided a suggestion to fix this issue.
```python
def std(x, dim=None, unbiased=True, keepdim=False):
    logger.debug("GEMS STD Forward")

    dim_list = (
        dim if isinstance(dim, (list, tuple)) else ([dim] if dim is not None else None)
    )

    variance, _ = var_mean(x, dim=dim_list, unbiased=unbiased, keepdim=keepdim)

    std_dev = gems_sqrt(variance)

    return std_dev
```
The `var_mean` function is being called with an `unbiased` keyword argument, but its signature expects a `correction` argument. This will cause a `TypeError` at runtime.

The `unbiased` boolean parameter from `std` needs to be converted to the integer `correction` parameter for `var_mean`. Specifically, `unbiased=True` should correspond to `correction=1`, and `unbiased=False` should correspond to `correction=0`.
Suggested change:

```diff
 def std(x, dim=None, unbiased=True, keepdim=False):
     logger.debug("GEMS STD Forward")
     dim_list = (
         dim if isinstance(dim, (list, tuple)) else ([dim] if dim is not None else None)
     )
-    variance, _ = var_mean(x, dim=dim_list, unbiased=unbiased, keepdim=keepdim)
+    correction = 1 if unbiased else 0
+    variance, _ = var_mean(x, dim=dim_list, correction=correction, keepdim=keepdim)
     std_dev = gems_sqrt(variance)
     return std_dev
```
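A quick accuracy sanity check after applying the fix might look like the following. This is a sketch assuming a CUDA device and FlagGems' documented `flag_gems.use_gems()` context manager; treat the exact enabling call as an assumption.

```python
import torch
import flag_gems

x = torch.randn(1024, 1024, device="cuda")
ref = torch.std(x, dim=1, unbiased=True)

with flag_gems.use_gems():  # route aten ops to the FlagGems implementations
    out = torch.std(x, dim=1, unbiased=True)

torch.testing.assert_close(out, ref)
```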
PR Category

Operator

Type of Change

New Feature

Description

Add the `std` operator.

Issue

Progress

Performance