
Conversation

hucheng-star

PR Category

Operator | OP Test | Benchmark

Type of Change

New Feature

Description

Develop the forward function for the acosh operator.

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

[Performance benchmark screenshots]

@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @hucheng-star, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request integrates the acosh (inverse hyperbolic cosine) operator into the flag_gems library. The primary goal is to expand the library's mathematical function capabilities by providing an optimized implementation of acosh using Triton, along with corresponding tests and performance measurements.
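
For reference (the standard definition, not stated in the PR itself), the function being added is

$$\operatorname{acosh}(x) = \ln\!\left(x + \sqrt{x^2 - 1}\right), \qquad x \ge 1,$$

which returns NaN for inputs below 1.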

Highlights

  • New acosh Operator Implementation: Introduces the forward function for the acosh (inverse hyperbolic cosine) operator.
  • Triton Kernel for acosh: The acosh operator's core logic is implemented using a Triton kernel for optimized performance.
  • Benchmarking and Testing: Includes updates to performance benchmarks and adds a new accuracy test for the acosh operator.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |
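
For example, commenting /gemini review on the pull request page triggers a fresh code review of the PR in its current state.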

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request adds the acosh operator. The implementation is functional but has some areas for improvement regarding generality and simplicity. My review provides suggestions to refactor the operator using the existing pointwise_dynamic utility for better robustness and to correct an issue in the unit test's assertion.

Comment on lines 1 to 51
```python
import logging
import triton
import triton.language as tl
import torch

from flag_gems.utils import pointwise_dynamic

logger = logging.getLogger(__name__)


@triton.jit
def acosh_forward_kernel(
    input_ptr,
    output_ptr,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements

    # Load the input data
    x = tl.load(input_ptr + offsets, mask=mask)

    # step1: if x < 1, return NaN
    one = tl.full((BLOCK_SIZE,), 1.0, dtype=tl.float32)
    # step2: compute (x^2 - 1)
    x_sq_minus_one = tl.where(
        x >= 1, x * x - one, tl.full((BLOCK_SIZE,), float('nan'), dtype=tl.float32)
    )
    # step3: compute sqrt(x^2 - 1)
    sqrt_val = tl.sqrt(tl.maximum(x_sq_minus_one, 0.0))
    # step4: compute x + sqrt(x^2 - 1)
    sum_val = x + sqrt_val
    # step5: compute ln(x + sqrt(x^2 - 1))
    result = tl.log(sum_val)
    # step6: store the result
    tl.store(output_ptr + offsets, result, mask=mask)


def acosh(input: torch.Tensor):
    output = torch.empty_like(input)
    n_elements = input.numel()
    BLOCK_SIZE = 1024
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    acosh_forward_kernel[grid](input, output, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    return output
```
@gemini-code-assist bot (Contributor) commented:

high

The current implementation of acosh is not generic; it only works for contiguous tensors that can be treated as 1D. The project includes a pointwise_dynamic utility which is designed to create generic pointwise operators. This utility is imported but not used.

Additionally, the Triton kernel acosh_forward_kernel is more complex than necessary and contains inefficiencies. For example:

  • The domain check for x >= 1 can be handled implicitly by tl.sqrt, which returns NaN for negative inputs.
  • tl.full((BLOCK_SIZE,), 1.0, ...) can be replaced by a scalar 1.0.
  • The comments are in Chinese, which is inconsistent with the rest of the codebase.

I suggest rewriting the file to use pointwise_dynamic for a much simpler, more robust, and more maintainable implementation that aligns with the design of other operators in this library.

Suggested change
```python
import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic


@pointwise_dynamic(
    promotion_methods=[(0, "DEFAULT")],
)
@triton.jit
def acosh(x):
    """
    Computes the inverse hyperbolic cosine of x.

    The domain is `x >= 1`. For `x < 1`, the result is `NaN`.
    The formula is `acosh(x) = log(x + sqrt(x*x - 1))`.
    Triton's `sqrt` of a negative number produces `NaN`, which propagates,
    so special handling for the domain is not required.
    """
    return tl.log(x + tl.sqrt(x * x - 1.0))
```
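
A minimal usage sketch (not part of the suggestion itself), assuming the decorated function is wired up to torch.acosh the same way other flag_gems pointwise ops are; flag_gems.device and use_gems() are taken from the test code in this PR:

```python
import torch
import flag_gems

x = torch.rand(1024, device=flag_gems.device) + 1.0  # stay in the domain x >= 1
ref = torch.acosh(x)            # eager PyTorch reference
with flag_gems.use_gems():      # dispatch aten ops to flag_gems kernels
    res = torch.acosh(x)
torch.testing.assert_close(res, ref)
```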

Collaborator

Yes, Gemini is right, and please delete the Chinese comments :)

done

```python
with flag_gems.use_gems():
    res_out = torch.acosh(inp)

gems_assert_close(res_out, ref_out, dtype)
```
@gemini-code-assist bot (Contributor) commented:

high

The input tensor inp is generated with torch.randn(...).exp() + 0.5, which can produce values less than 1. For these out-of-domain values, torch.acosh returns NaN. To ensure your implementation also correctly returns NaN and the test verifies this, you should enable NaN comparison in the assertion.

Suggested change

```python
gems_assert_close(res_out, ref_out, dtype, equal_nan=True)
```
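
For background (not from the PR): under IEEE-754, NaN compares unequal to everything, including itself, so an assertion without equal_nan=True fails even when both tensors hold matching NaNs. A quick illustration with plain torch.allclose:

```python
import torch

a = torch.tensor([float("nan"), 1.0])
b = torch.tensor([float("nan"), 1.0])
print(torch.allclose(a, b))                  # False: NaN != NaN by default
print(torch.allclose(a, b, equal_nan=True))  # True: NaNs treated as equal
```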

done

@meinie0826 (Collaborator)

Great! Did you retest the perf?

@huc85153-star

> Great! Did you retest the perf?

Yes, the performance is great, but I forgot to update the results. I will update them as soon as possible.

@huc85153-star

Updated the performance results:

[Performance benchmark screenshots]

@huc85153-star

[Additional performance screenshot]

@meinie0826 (Collaborator) commented Sep 11, 2025

Great! Please add logger.debug("xxx") to this.
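
A minimal sketch of the requested call, assuming the convention used by other flag_gems ops of logging the op name at debug level (the exact message "xxx" above is the reviewer's placeholder):

```python
import logging

logger = logging.getLogger(__name__)

def acosh(input):
    logger.debug("GEMS ACOSH")  # assumed message, following the style of other ops
    ...
```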

"""
result = acosh_forward_kernel(input)

if out is not None:
@StrongSpoon (Collaborator) commented Sep 18, 2025

I suggest referring to the implementation in https://github.com/FlagOpen/FlagGems/blob/master/src/flag_gems/ops/add.py#L48. copy_ calls another kernel.
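
A hedged sketch of the pattern that comment points at (the names acosh_kernel and the out0 keyword are assumptions based on the linked add.py, not code from this PR): writing into out via a separate out.copy_(result) launches a second kernel, so the generated kernel should store into out directly:

```python
def acosh(input, *, out=None):
    logger.debug("GEMS ACOSH")  # assumed log message, per the earlier review note
    if out is None:
        return acosh_kernel(input)  # hypothetical pointwise_dynamic-generated kernel
    acosh_kernel(input, out0=out)   # hypothetical out0 kwarg: write in place,
    return out                      # avoiding a second kernel from out.copy_()
```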

```python
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_acosh(shape, dtype):
    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device).exp() + 0.5
```
Collaborator

What if the value of the inp tensor is invalid? Will torch check it?
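
Not from the PR, but relevant to the question: torch does not validate the domain; torch.acosh silently returns NaN for x < 1, which is what the equal_nan=True assertion above accounts for. For example:

```python
import torch

x = torch.tensor([0.5, 1.0, 2.0])
print(torch.acosh(x))  # tensor([   nan, 0.0000, 1.3170]) -- no error is raised
```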



Development

Successfully merging this pull request may close these issues.

Code Contribution: 【Easy】【Operator Development】acosh

4 participants