
Conversation

@botbigeyes (Contributor) commented Oct 13, 2025

PR Category
Operator

Type of Change
New Feature

Description
Add the std operator.

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

Operator: std  Performance Test (dtype=torch.float16, mode=kernel, level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Torch GBPS            Gems GBPS           Size Detail
-----------------------------------------------------------------------------------------------------------------------------------------
SUCCESS               0.016480            0.017760               0.928             254.509             236.166          [torch.Size([1048576])]
SUCCESS               0.007616            0.007968               0.956               2.151               2.056          [torch.Size([64, 64]), 1]
SUCCESS               0.045376            0.045408               0.999            1478.951            1477.908          [torch.Size([4096, 4096]), 1]
SUCCESS               0.059424            0.059360               1.001            1129.323            1130.540          [torch.Size([64, 512, 512]), 1]
SUCCESS               0.955008            0.959136               0.996            4497.310            4477.954          [torch.Size([1024, 1024, 1024]), 1]
SUCCESS               0.016832            0.016704               1.008             249.430             251.341          [torch.Size([1049600])]
SUCCESS               1.180512            1.180832               1.000            3638.224            3637.238          [torch.Size([1073741824])]
SUCCESS               0.008288            0.008768               0.945               0.494               0.467          [torch.Size([1024, 1]), 1]
SUCCESS               0.007520            0.007808               0.963               8.715               8.393          [torch.Size([1024, 16]), 1]
SUCCESS               0.009632            0.009440               1.020             108.864             111.078          [torch.Size([1024, 256]), 1]
SUCCESS               0.037600            0.037600               1.000             446.203             446.203          [torch.Size([1024, 4096]), 1]
SUCCESS               0.092736            0.093088               0.996            2894.620            2883.674          [torch.Size([1024, 65536]), 1]
SUCCESS               1.163584            1.163808               1.000            3691.154            3690.443          [torch.Size([1024, 1048576]), 1]
SUCCESS               0.008384            0.008096               1.036               1.954               2.024          [torch.Size([64, 1, 64]), 1]
SUCCESS               0.012832            0.012576               1.020              20.429              20.845          [torch.Size([64, 16, 64]), 1]
SUCCESS               0.030112            0.030176               0.998             139.290             138.995          [torch.Size([64, 256, 64]), 1]
SUCCESS               0.091504            0.091328               1.002             733.398             734.811          [torch.Size([64, 4096, 64]), 1]
Operator: std  Performance Test (dtype=torch.float32, mode=kernel, level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Torch GBPS            Gems GBPS           Size Detail
-----------------------------------------------------------------------------------------------------------------------------------------
SUCCESS               0.016928            0.016960               0.998             495.546             494.611          [torch.Size([1048576])]
SUCCESS               0.008224            0.007936               1.036               3.984               4.129          [torch.Size([64, 64]), 1]
SUCCESS               0.050704            0.051040               0.993            2647.084            2629.658          [torch.Size([4096, 4096]), 1]
SUCCESS               0.062656            0.062976               0.995            2142.137            2131.252          [torch.Size([64, 512, 512]), 1]
SUCCESS               1.387840            1.388224               1.000            6189.427            6187.715          [torch.Size([1024, 1024, 1024]), 1]
SUCCESS               0.017440            0.017504               0.996             481.468             479.708          [torch.Size([1049600])]
SUCCESS               1.502048            1.501440               1.000            5718.815            5721.131          [torch.Size([1073741824])]
SUCCESS               0.008128            0.008448               0.962               1.008               0.970          [torch.Size([1024, 1]), 1]
SUCCESS               0.007744            0.007520               1.030              16.926              17.430          [torch.Size([1024, 16]), 1]
SUCCESS               0.009888            0.009856               1.003             212.091             212.779          [torch.Size([1024, 256]), 1]
SUCCESS               0.040032            0.040064               0.999             838.190             837.521          [torch.Size([1024, 4096]), 1]
SUCCESS               0.112224            0.112064               1.001            4783.923            4790.753          [torch.Size([1024, 65536]), 1]
SUCCESS               1.502528            1.502928               1.000            5716.988            5715.467          [torch.Size([1024, 1048576]), 1]
SUCCESS               0.008352            0.008352               1.000               3.923               3.923          [torch.Size([64, 1, 64]), 1]
SUCCESS               0.012544            0.012768               0.982              41.796              41.063          [torch.Size([64, 16, 64]), 1]
SUCCESS               0.031136            0.031168               0.999             269.418             269.142          [torch.Size([64, 256, 64]), 1]
SUCCESS               0.100448            0.100224               1.002            1336.191            1339.177          [torch.Size([64, 4096, 64]), 1]
Operator: std  Performance Test (dtype=torch.bfloat16, mode=kernel, level=comprehensive)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Torch GBPS            Gems GBPS           Size Detail
-----------------------------------------------------------------------------------------------------------------------------------------
SUCCESS               0.016512            0.016512               1.000             254.016             254.016          [torch.Size([1048576])]
SUCCESS               0.007776            0.007552               1.030               2.107               2.169          [torch.Size([64, 64]), 1]
SUCCESS               0.045024            0.045376               0.992            1490.513            1478.951          [torch.Size([4096, 4096]), 1]
SUCCESS               0.059200            0.059456               0.996            1133.596            1128.715          [torch.Size([64, 512, 512]), 1]
SUCCESS               0.959456            0.959920               1.000            4476.461            4474.297          [torch.Size([1024, 1024, 1024]), 1]
SUCCESS               0.017120            0.017088               1.002             245.234             245.693          [torch.Size([1049600])]
SUCCESS               1.191824            1.191872               1.000            3603.693            3603.547          [torch.Size([1073741824])]
SUCCESS               0.008288            0.008320               0.996               0.494               0.492          [torch.Size([1024, 1]), 1]
SUCCESS               0.007488            0.007520               0.996               8.752               8.715          [torch.Size([1024, 16]), 1]
SUCCESS               0.009600            0.009632               0.997             109.227             108.864          [torch.Size([1024, 256]), 1]
SUCCESS               0.037568            0.037824               0.993             446.583             443.560          [torch.Size([1024, 4096]), 1]
SUCCESS               0.093856            0.094016               0.998            2860.078            2855.210          [torch.Size([1024, 65536]), 1]
SUCCESS               1.179296            1.178880               1.000            3641.976            3643.261          [torch.Size([1024, 1048576]), 1]
SUCCESS               0.008192            0.008192               1.000               2.000               2.000          [torch.Size([64, 1, 64]), 1]
SUCCESS               0.012640            0.012640               1.000              20.739              20.739          [torch.Size([64, 16, 64]), 1]
SUCCESS               0.030080            0.030112               0.999             139.438             139.290          [torch.Size([64, 256, 64]), 1]
SUCCESS               0.091296            0.091296               1.000             735.069             735.069          [torch.Size([64, 4096, 64]), 1]


Summary of Changes

Hello @botbigeyes, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new standard deviation (std) operator to the flag_gems library, implemented with high-performance Triton kernels. The operator supports flexible usage, including calculations along specified dimensions or over the entire tensor, with options for unbiased estimation and preserving dimensions. The changes include the core implementation, integration into the library's API, comprehensive unit tests to ensure accuracy against PyTorch, and inclusion in the performance benchmarking framework to monitor its efficiency.
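
As context for reviewers, a minimal usage sketch of how the new operator would be exercised through torch.std. This assumes the flag_gems.use_gems() context manager exposed by the library; the tensor shape and dtype are taken from the benchmarks above, and the snippet is illustrative rather than code from this PR:

    import torch
    import flag_gems

    x = torch.randn(64, 512, 512, device="cuda", dtype=torch.float16)
    with flag_gems.use_gems():
        # torch.std is dispatched to the Gems Triton implementation while the context is active.
        out = torch.std(x, dim=1, unbiased=True, keepdim=False)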

Highlights

  • New std Operator Implementation: A new standard deviation (std) operator has been implemented using Triton kernels, supporting both full-tensor and dimension-wise calculations.
  • Welford's Algorithm for Variance: The std operator internally uses Welford's algorithm for numerically stable computation of variance, which is then used to derive the standard deviation (a minimal sketch of the update follows this list).
  • Comprehensive Accuracy Testing: Extensive accuracy tests have been added to verify the correctness of the std operator against PyTorch's native implementation across various shapes, dimensions, unbiased settings, keepdim settings, and data types.
  • Performance Benchmarking Integration: The std operator has been integrated into the performance benchmarking suite, with initial results showing competitive latency and GBPS compared to Torch for various tensor sizes and data types.
  • Library Integration: The new std operator is now exposed and enabled within the flag_gems library, making it available for use when flag_gems is active.
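
For reference, Welford's update can be sketched in a few lines of Python. This is the textbook single-pass recurrence, not the Triton kernel from the PR:

    def welford_update(count, mean, m2, value):
        # Single-pass update of the running count, mean, and sum of squared deviations (M2).
        count += 1
        delta = value - mean
        mean += delta / count
        m2 += delta * (value - mean)
        return count, mean, m2

    # After consuming n values: variance = m2 / (n - correction), std = variance ** 0.5
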
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature                Command                Description
Code Review            /gemini review         Performs a code review for the current pull request in its current state.
Pull Request Summary   /gemini summary        Provides a summary of the current pull request in its current state.
Comment                @gemini-code-assist    Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help                   /gemini help           Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces a new std operator. The implementation is functional but has several areas for improvement. The new file src/flag_gems/ops/std.py contains a significant amount of code duplicated from the existing var_mean operator, which is a concern for maintainability. Additionally, the kernel for calculating variance along a specific dimension uses a numerically unstable algorithm, and the operator currently lacks support for automatic differentiation (autograd), which might limit its use in training scenarios. I've provided detailed comments and suggestions to address these issues, including a proposal to refactor std.py to reuse the var_mean op, a fix for the numerical stability of the kernel, and a note on adding autograd support. A minor efficiency improvement for the new test case is also suggested.
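
To illustrate the numerical-stability concern with a generic example (not the PR's kernel): the single-pass E[x²] − E[x]² formula cancels catastrophically when the data's magnitude is large relative to its spread, which is exactly the failure mode Welford's algorithm avoids.

    import torch

    # Values around 1e4 with unit spread: the true variance is ~1.
    x = 1e4 + torch.randn(10000, dtype=torch.float32)

    naive_var = (x * x).mean() - x.mean() ** 2   # E[x^2] - E[x]^2, unstable in float32
    stable_var = ((x - x.mean()) ** 2).mean()    # shift-by-mean (two-pass) formula, stable
    print(naive_var.item(), stable_var.item(), x.var(unbiased=False).item())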

@0x45f (Collaborator) commented Oct 13, 2025

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for the std operator. The implementation, however, has a critical issue in the forward pass (std function), which returns a tuple instead of a single tensor, contrary to torch.std's behavior. This will cause the provided tests and any dependent code to fail. The backward pass implementation (std_backward) relies on this incorrect signature, and its registration uses a non-standard operator name, indicating potential design issues that need to be addressed.
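
For reference, the return shapes being compared (plain PyTorch, independent of this PR):

    import torch

    x = torch.randn(4, 8)
    s = torch.std(x, dim=1)           # torch.std returns a single tensor
    v, m = torch.var_mean(x, dim=1)   # torch.var_mean returns a (variance, mean) tuple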

@0x45f (Collaborator) commented Oct 14, 2025

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for the torch.std operation by composing existing var_mean and sqrt operators for the forward pass, and providing a custom Triton kernel for the backward pass. The changes include the operator implementation, registration, and corresponding benchmark and accuracy tests. My review identifies a critical issue in the backward pass implementation (std_backward) where the logic for handling partial reductions with keepdim=False is flawed. This can lead to incorrect gradient calculations. I've provided a detailed explanation and a partial fix.
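
For context, a reference eager-mode formulation of the gradient under discussion (assuming a single integer dim; the function name is illustrative, not the PR's Triton kernel). The key detail is that with keepdim=False the reduced dimension must be restored before broadcasting:

    import torch

    def std_backward_reference(grad_out, x, dim, correction, keepdim, std_out):
        # d std / d x = (x - mean) / ((n - correction) * std)
        if not keepdim:
            # Restore the reduced dim so grad_out and std_out broadcast against x.
            grad_out = grad_out.unsqueeze(dim)
            std_out = std_out.unsqueeze(dim)
        n = x.shape[dim]
        mean = x.mean(dim=dim, keepdim=True)
        return grad_out * (x - mean) / ((n - correction) * std_out)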

@0x45f (Collaborator) commented Oct 15, 2025

/gemini review


@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for the torch.std operator. The changes include the operator implementation, registration, performance benchmarks, and accuracy tests. The overall approach is sound, but there is a critical bug in the std implementation. It incorrectly passes the unbiased parameter to the var_mean function, which expects a correction parameter, leading to a TypeError. I have provided a suggestion to fix this issue.

Comment on lines 9 to 20
    def std(x, dim=None, unbiased=True, keepdim=False):
        logger.debug("GEMS STD Forward")

        dim_list = (
            dim if isinstance(dim, (list, tuple)) else ([dim] if dim is not None else None)
        )

        variance, _ = var_mean(x, dim=dim_list, unbiased=unbiased, keepdim=keepdim)

        std_dev = gems_sqrt(variance)

        return std_dev


critical

The var_mean function is being called with an unbiased keyword argument, but its signature expects a correction argument. This will cause a TypeError at runtime.

The unbiased boolean parameter from std needs to be converted to the integer correction parameter for var_mean. Specifically, unbiased=True should correspond to correction=1, and unbiased=False should correspond to correction=0.

Suggested change

Before:

    def std(x, dim=None, unbiased=True, keepdim=False):
        logger.debug("GEMS STD Forward")
        dim_list = (
            dim if isinstance(dim, (list, tuple)) else ([dim] if dim is not None else None)
        )
        variance, _ = var_mean(x, dim=dim_list, unbiased=unbiased, keepdim=keepdim)
        std_dev = gems_sqrt(variance)
        return std_dev

After:

    def std(x, dim=None, unbiased=True, keepdim=False):
        logger.debug("GEMS STD Forward")
        dim_list = (
            dim if isinstance(dim, (list, tuple)) else ([dim] if dim is not None else None)
        )
        correction = 1 if unbiased else 0
        variance, _ = var_mean(x, dim=dim_list, correction=correction, keepdim=keepdim)
        std_dev = gems_sqrt(variance)
        return std_dev
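
As a quick sanity check of the unbiased-to-correction mapping (plain PyTorch, independent of the Gems kernels; requires a PyTorch version that exposes the correction keyword):

    import torch

    x = torch.randn(8, 16)
    # unbiased=True corresponds to correction=1 (Bessel's correction); unbiased=False to correction=0.
    assert torch.allclose(torch.std(x, dim=1, unbiased=True), torch.std(x, dim=1, correction=1))
    assert torch.allclose(torch.std(x, dim=1, unbiased=False), torch.std(x, dim=1, correction=0))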

