1 change: 1 addition & 0 deletions benchmark/test_unary_pointwise_perf.py
@@ -40,6 +40,7 @@ def get_tflops(self, op, *args, **kwargs):

forward_operations = [
("abs", torch.abs, FLOAT_DTYPES),
("acosh", torch.acosh, FLOAT_DTYPES),
("angle", torch.angle, COMPLEX_DTYPES + [torch.float32] + INT_DTYPES + BOOL_DTYPES),
("erf", torch.erf, FLOAT_DTYPES),
("exp", torch.exp, FLOAT_DTYPES),
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -43,6 +43,7 @@ def enable(
("_weight_norm_interface_backward", weight_norm_interface_backward),
("abs", abs),
("abs_", abs_),
("acosh", acosh),
("add.Tensor", add),
("add_.Tensor", add_),
("addcdiv", addcdiv),
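
Once the ("acosh", acosh) pair is registered in enable(), torch.acosh dispatches to the FlagGems kernel inside the use_gems() context, exactly as the new test further down exercises. A minimal usage sketch (illustrative only, assuming a device supported by flag_gems):

import torch

import flag_gems

x = torch.rand(1024, device=flag_gems.device) + 1.0  # keep inputs in acosh's domain (x >= 1)
with flag_gems.use_gems():  # patches the registered aten ops, now including "acosh"
    y = torch.acosh(x)      # runs the Triton kernel from flag_gems.ops.acosh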
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -1,4 +1,5 @@
from flag_gems.ops.abs import abs, abs_
from flag_gems.ops.acosh import acosh
from flag_gems.ops.add import add, add_
from flag_gems.ops.addcdiv import addcdiv
from flag_gems.ops.addcmul import addcmul
@@ -207,6 +208,7 @@
"_upsample_bicubic2d_aa",
"abs",
"abs_",
"acosh",
"add",
"add_",
"addcdiv",
34 changes: 34 additions & 0 deletions src/flag_gems/ops/acosh.py
@@ -0,0 +1,34 @@
import logging

import torch
import triton
import triton.language as tl

from flag_gems.utils import pointwise_dynamic

logger = logging.getLogger(__name__)

@pointwise_dynamic(
    promotion_methods=[(0, "DEFAULT")],
)
@triton.jit
def acosh_forward_kernel(x):
    # acosh(x) = log(x + sqrt(x^2 - 1)), defined for x >= 1; out-of-domain inputs yield NaN
    return tl.log(x + tl.sqrt(x * x - 1.0))

def acosh(input: torch.Tensor, *, out: torch.Tensor = None):
"""
Returns a new tensor with the inverse hyperbolic cosine of the elements of input.
Args:
input (Tensor): the input tensor
out (Tensor, optional): the output tensor
Returns:
Tensor: the output tensor with the inverse hyperbolic cosine values
"""
result = acosh_forward_kernel(input)

if out is not None:
StrongSpoon (Collaborator) commented on Sep 18, 2025:

I suggest referring to the implementation in https://github.com/FlagOpen/FlagGems/blob/master/src/flag_gems/ops/add.py#L48; copy_ calls another kernel.

        out.copy_(result)
        return out

    return result
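
Following up on StrongSpoon's comment above: out.copy_(result) launches a second kernel after the pointwise one. A minimal sketch of the alternative, assuming the pointwise_dynamic-generated function accepts an out0= keyword for writing into a preallocated tensor (the actual pattern should be taken from the referenced add.py, not from this sketch):

# Sketch only: reuses acosh_forward_kernel and logger defined in acosh.py above.
def acosh(input: torch.Tensor, *, out: torch.Tensor = None):
    logger.debug("GEMS ACOSH")
    if out is None:
        return acosh_forward_kernel(input)
    # Assumption: out0= lets the generated function write directly into `out`,
    # avoiding the extra copy_ kernel launch.
    return acosh_forward_kernel(input, out0=out)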
14 changes: 14 additions & 0 deletions tests/test_unary_pointwise_ops.py
@@ -33,6 +33,20 @@ def test_accuracy_abs(shape, dtype):
gems_assert_equal(res_out, ref_out)


@pytest.mark.acosh
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_acosh(shape, dtype):
    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device).exp() + 0.5
Collaborator comment:

What if the value of the inp tensor is invalid? Will torch check it?


    ref_inp = to_reference(inp, True)
    ref_out = torch.acosh(ref_inp)
    with flag_gems.use_gems():
        res_out = torch.acosh(inp)

    gems_assert_close(res_out, ref_out, dtype, equal_nan=True)
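
On the collaborator's question above: torch.acosh does not raise on out-of-domain values; for real floating dtypes it returns NaN below 1, and the Triton formula log(x + sqrt(x * x - 1)) behaves the same way. Note that exp() + 0.5 can still yield values in (0.5, 1), which is presumably why the comparison uses equal_nan=True. A hypothetical extra test (not part of this PR) that targets the NaN path directly, assuming the same imports and helpers as this file:

@pytest.mark.acosh
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_acosh_out_of_domain(dtype):
    # Every value here is < 1, i.e. outside acosh's domain; both eager torch
    # and the gems kernel are expected to return NaN rather than raise.
    inp = torch.linspace(-2.0, 0.99, 64, dtype=dtype, device=flag_gems.device)
    ref_inp = to_reference(inp, True)
    ref_out = torch.acosh(ref_inp)
    with flag_gems.use_gems():
        res_out = torch.acosh(inp)

    gems_assert_close(res_out, ref_out, dtype, equal_nan=True)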


@pytest.mark.inplace
@pytest.mark.abs_
@pytest.mark.parametrize("shape", POINTWISE_SHAPES)