Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions benchmark/test_reduction_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,22 @@ def dot_input_fn(shape, dtype, device):
bench.run()


@pytest.mark.trace
def test_perf_trace():
def trace_input_fn(shape, dtype, device):
inp = generate_tensor_input(shape, dtype=dtype, device=device)
yield inp,

bench = GenericBenchmark2DOnly(
input_fn=trace_input_fn,
op_name="trace",
torch_op=torch.trace,
dtypes=FLOAT_DTYPES + INT_DTYPES,
)

bench.run()


class quantileBenchmark(GenericBenchmark):
def set_more_shapes(self):
more_shapes_1d = [(4,), (1024,), (65535)]
Expand Down
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,7 @@ def enable(
("tile", tile),
("to.dtype", to_dtype),
("topk", topk),
("trace", trace),
("triu", triu),
("true_divide.Scalar", true_divide),
("true_divide.Tensor", true_divide),
Expand Down
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@
from flag_gems.ops.tile import tile
from flag_gems.ops.to import to_dtype
from flag_gems.ops.topk import topk
from flag_gems.ops.trace import trace
from flag_gems.ops.triu import triu
from flag_gems.ops.uniform import uniform_
from flag_gems.ops.unique import _unique2
Expand Down Expand Up @@ -441,6 +442,7 @@
"tile",
"to_dtype",
"topk",
"trace",
"triu",
"true_divide",
"true_divide_",
Expand Down
95 changes: 95 additions & 0 deletions src/flag_gems/ops/trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import logging

import torch
import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry

logger = logging.getLogger(__name__)


@libentry()
@triton.jit
def trace_kernel(
inp_ptr,
out_ptr,
num_diag,
stride0,
stride1,
BLOCK_SIZE: tl.constexpr,
):
inp_dtype = inp_ptr.type.element_ty
if inp_dtype.is_int():
acc_dtype = tl.int64
other_val = 0
elif inp_dtype == tl.float64:
acc_dtype = tl.float64
other_val = 0.0
else:
acc_dtype = tl.float32
other_val = 0.0

acc = tl.zeros((BLOCK_SIZE,), dtype=acc_dtype)

diag_stride = stride0 + stride1

for i in range(0, tl.cdiv(num_diag, BLOCK_SIZE)):
block_start = i * BLOCK_SIZE
current_indices = block_start + tl.arange(0, BLOCK_SIZE)

mask = current_indices < num_diag

ptr_offsets = current_indices * diag_stride
current_ptrs = inp_ptr + ptr_offsets

vals = tl.load(current_ptrs, mask=mask, other=other_val)

acc += vals.to(acc_dtype)

final_sum = tl.sum(acc, axis=0)
tl.store(out_ptr, final_sum.to(out_ptr.type.element_ty))


def trace(self):
logger.debug("GEMS TRACE")

if self.ndim != 2:
raise RuntimeError(
f"trace: expected a 2D tensor, but got a {self.ndim}D tensor"
)

M, N = self.shape
stride0, stride1 = self.stride()
num_diag = min(M, N)
if num_diag == 0:
if self.dtype.is_floating_point:
return torch.tensor(0.0, dtype=self.dtype, device=self.device)
else:
return torch.tensor(0, dtype=torch.int64, device=self.device)

if self.dtype.is_floating_point:
output_dtype = self.dtype
else:
output_dtype = torch.int64
out = torch.empty((), dtype=output_dtype, device=self.device)

grid = (1,)
BLOCK_SIZE = 1024
if num_diag < BLOCK_SIZE:
BLOCK_SIZE = triton.next_power_of_2(num_diag)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider partition the task of summing elements on the diagonal to several blocks when the number of elements to sum is large.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can treat it like a normal reduction on a vector whose stride is stride_dim0 + stride_dim1.

if BLOCK_SIZE == 0:
BLOCK_SIZE = 1
Comment on lines +82 to +83
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This check for BLOCK_SIZE == 0 is redundant and can be removed. The num_diag value is checked for 0 on line 66, and the function returns early if it is. Therefore, num_diag is guaranteed to be at least 1 at this point. Since triton.next_power_of_2(x) returns a positive integer for any positive x, BLOCK_SIZE will never be 0 here.


with torch_device_fn.device(self.device):
trace_kernel[grid](
self,
out,
num_diag,
stride0,
stride1,
BLOCK_SIZE=BLOCK_SIZE,
)

return out
37 changes: 37 additions & 0 deletions tests/test_reduction_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,43 @@ def test_accuracy_inplace_scatter_mul(src_shape, inp_shape, dim, dtype):
gems_assert_close(res_out, ref_out, dtype)


TRACE_SHAPES = [
(1, 1),
(5, 5),
(10, 20),
(30, 15),
(1, 100),
(100, 1),
(128, 256),
(256, 128),
(0, 10), # empty diagonal
(10, 0), # empty diagonal
(1500, 1200), # Larger shape
]


@pytest.mark.trace
@pytest.mark.parametrize("shape", TRACE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES + [torch.bool])
def test_accuracy_trace(shape, dtype):
if dtype == torch.bool:
inp = torch.randint(0, 2, size=shape, device=flag_gems.device).to(dtype)
elif dtype in INT_DTYPES:
inp = torch.randint(-100, 100, size=shape, device=flag_gems.device).to(dtype)
else:
inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)

ref_inp = to_reference(inp)
ref_out = torch.trace(ref_inp)
with flag_gems.use_gems():
res_out = torch.trace(inp)

if dtype in FLOAT_DTYPES:
gems_assert_close(res_out, ref_out, dtype)
else:
gems_assert_equal(res_out, ref_out)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The test coverage is good. To make it even more robust, consider adding test cases for non-contiguous tensors, such as a transposed matrix. This would ensure that the implementation correctly handles different memory layouts via strides. You could achieve this by creating a transposed tensor within the existing test_accuracy_trace function and running the same assertions on it.



@pytest.mark.gather
@pytest.mark.parametrize(
"inp_shape",
Expand Down
Loading