acosh process #942
Conversation
Summary of Changes
Hello @hucheng-star, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request integrates the acosh (inverse hyperbolic cosine) operator into the flag_gems library. The primary goal is to expand the library's mathematical function capabilities by providing an optimized implementation of acosh using Triton, along with corresponding tests and performance measurements.
Highlights
- New acosh Operator Implementation: Introduces the forward function for the acosh (inverse hyperbolic cosine) operator.
- Triton Kernel for acosh: The acosh operator's core logic is implemented using a Triton kernel for optimized performance.
- Benchmarking and Testing: Includes updates to performance benchmarks and adds a new accuracy test for the acosh operator.
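As a quick illustration (a minimal sketch based on the accuracy test in this PR, assuming the operator is registered so that `flag_gems.use_gems()` routes `torch.acosh` to the Triton implementation), the operator would typically be exercised like this:

```python
import torch
import flag_gems

# Input mixes in-domain (>= 1) and out-of-domain (< 1) values, mirroring the test setup.
x = torch.randn(1024, device=flag_gems.device).exp() + 0.5

with flag_gems.use_gems():          # dispatch torch ops to the flag_gems Triton kernels
    res = torch.acosh(x)

ref = torch.acosh(x)                # eager PyTorch reference
print(torch.allclose(res, ref, equal_nan=True))
```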
Code Review
This pull request adds the `acosh` operator. The implementation is functional but has some areas for improvement regarding generality and simplicity. My review provides suggestions to refactor the operator using the existing `pointwise_dynamic` utility for better robustness, and to correct an issue in the unit test's assertion.
```python
import logging

import triton
import triton.language as tl
import torch

from flag_gems.utils import pointwise_dynamic

logger = logging.getLogger(__name__)


@triton.jit
def acosh_forward_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # 加载输入数据 (load the input data)
    x = tl.load(input_ptr + offsets, mask=mask)

    # step1: 若 x < 1, 则返回 nan (if x < 1, return NaN)
    one = tl.full((BLOCK_SIZE,), 1.0, dtype=tl.float32)
    # step2: 计算 (x^2 - 1) (compute x^2 - 1)
    x_sq_minus_one = tl.where(x >= 1, x * x - one, tl.full((BLOCK_SIZE,), float('nan'), dtype=tl.float32))
    # step3: 计算 sqrt(x^2 - 1) (compute sqrt(x^2 - 1))
    sqrt_val = tl.sqrt(tl.maximum(x_sq_minus_one, 0.0))
    # step4: 计算 x + sqrt(x^2 - 1) (compute x + sqrt(x^2 - 1))
    sum_val = x + sqrt_val
    # step5: 计算 ln(x + sqrt(x^2 - 1)) (compute ln(x + sqrt(x^2 - 1)))
    result = tl.log(sum_val)
    # step6: 保存结果 (store the result)
    tl.store(output_ptr + offsets, result, mask=mask)


def acosh(input: torch.Tensor):
    output = torch.empty_like(input)

    n_elements = input.numel()

    BLOCK_SIZE = 1024

    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)

    acosh_forward_kernel[grid](
        input, output, n_elements, BLOCK_SIZE=BLOCK_SIZE
    )

    return output
```
The current implementation of `acosh` is not generic; it only works for contiguous tensors that can be treated as 1D. The project includes a `pointwise_dynamic` utility which is designed to create generic pointwise operators. This utility is imported but not used.
Additionally, the Triton kernel `acosh_forward_kernel` is more complex than necessary and contains inefficiencies. For example:
- The domain check for `x >= 1` can be handled implicitly by `tl.sqrt`, which returns `NaN` for negative inputs.
- `tl.full((BLOCK_SIZE,), 1.0, ...)` can be replaced by a scalar `1.0`.
- The comments are in Chinese, which is inconsistent with the rest of the codebase.
I suggest rewriting the file to use `pointwise_dynamic` for a much simpler, more robust, and more maintainable implementation that aligns with the design of other operators in this library.
Suggested change:

```python
import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic


@pointwise_dynamic(
    promotion_methods=[(0, "DEFAULT")],
)
@triton.jit
def acosh(x):
    """
    Computes the inverse hyperbolic cosine of x.

    The domain is `x >= 1`. For `x < 1`, the result is `NaN`.
    The formula is `acosh(x) = log(x + sqrt(x*x - 1))`.
    Triton's `sqrt` of a negative number produces `NaN`, which propagates,
    so special handling for the domain is not required.
    """
    return tl.log(x + tl.sqrt(x * x - 1.0))
```
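As a side note, the closed form used above can be sanity-checked against eager PyTorch on in-domain inputs (plain PyTorch, no Triton or flag_gems required):

```python
import torch

x = torch.rand(8, dtype=torch.float64) * 10 + 1.0       # in-domain values in [1, 11)
closed_form = torch.log(x + torch.sqrt(x * x - 1.0))     # acosh(x) = log(x + sqrt(x^2 - 1))
print(torch.allclose(closed_form, torch.acosh(x)))       # expected: True
```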
Yes, Gemini is right, and please delete the Chinese comments :)
done
tests/test_unary_pointwise_ops.py (Outdated)
```python
    with flag_gems.use_gems():
        res_out = torch.acosh(inp)

    gems_assert_close(res_out, ref_out, dtype)
```
The input tensor `inp` is generated with `torch.randn(...).exp() + 0.5`, which can produce values less than 1: the exponential of a negative sample lies in (0, 1), so `inp` can fall anywhere in (0.5, 1.5). For these out-of-domain values, `torch.acosh` returns `NaN`. To ensure your implementation also correctly returns `NaN` and the test verifies this, you should enable NaN comparison in the assertion.
Suggested change:

```python
    gems_assert_close(res_out, ref_out, dtype, equal_nan=True)
```
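To illustrate why `equal_nan=True` is needed, here is a small standalone check (plain PyTorch, independent of the test harness) of the NaN behavior described above:

```python
import torch

inp = torch.tensor([0.5, 1.0, 2.0])
ref = torch.acosh(inp)                               # tensor([   nan, 0.0000, 1.3170])
res = torch.log(inp + torch.sqrt(inp * inp - 1.0))   # same closed form; also NaN at 0.5
print(torch.allclose(res, ref))                      # False: NaN != NaN by default
print(torch.allclose(res, ref, equal_nan=True))      # True
```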
done
Great! Did you retest the perf?
Yes, performance is great, but I forgot to update the result. I will update it as soon as possible.
Updated the performance result.
Great! Please add
""" | ||
result = acosh_forward_kernel(input) | ||
|
||
if out is not None: |
I suggest referring to the implementation in https://github.com/FlagOpen/FlagGems/blob/master/src/flag_gems/ops/add.py#L48. `copy_` calls another kernel.
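For context, the pattern under discussion looks roughly like the sketch below (a hypothetical reconstruction using `torch.acosh` as a stand-in for the Triton launch; it is not the code in this PR). The point of the review is that `out.copy_(result)` dispatches a second kernel just to move data into `out`, which, per the reviewer, the approach in the referenced `add.py` avoids:

```python
import torch

def acosh(input: torch.Tensor, *, out: torch.Tensor = None) -> torch.Tensor:
    # Stand-in for the Triton kernel launch used in this PR.
    result = torch.acosh(input)
    if out is not None:
        # copy_ launches an extra copy kernel on top of the compute kernel;
        # this is the overhead the reviewer suggests avoiding.
        out.copy_(result)
        return out
    return result
```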
```python
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_acosh(shape, dtype):
    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device).exp() + 0.5
```
What if the values of the `inp` tensor are invalid? Will torch check them?
PR Category
Operator | OP Test | Benchmark
Type of Change
New Feature
Description
Develop the forward function for the acosh operator.
Issue
Progress
Performance