Add trace #992
base: master
@@ -0,0 +1,95 @@
import logging

import torch
import triton
import triton.language as tl

from flag_gems.runtime import torch_device_fn
from flag_gems.utils import libentry

logger = logging.getLogger(__name__)


@libentry()
@triton.jit
def trace_kernel(
    inp_ptr,
    out_ptr,
    num_diag,
    stride0,
    stride1,
    BLOCK_SIZE: tl.constexpr,
):
    # Accumulate in int64 for integer inputs, float64 for float64, float32 otherwise.
    inp_dtype = inp_ptr.type.element_ty
    if inp_dtype.is_int():
        acc_dtype = tl.int64
        other_val = 0
    elif inp_dtype == tl.float64:
        acc_dtype = tl.float64
        other_val = 0.0
    else:
        acc_dtype = tl.float32
        other_val = 0.0

    acc = tl.zeros((BLOCK_SIZE,), dtype=acc_dtype)

    # The main diagonal is a strided vector: element i lives at offset i * (stride0 + stride1).
    diag_stride = stride0 + stride1

    for i in range(0, tl.cdiv(num_diag, BLOCK_SIZE)):
        block_start = i * BLOCK_SIZE
        current_indices = block_start + tl.arange(0, BLOCK_SIZE)

        mask = current_indices < num_diag

        ptr_offsets = current_indices * diag_stride
        current_ptrs = inp_ptr + ptr_offsets

        vals = tl.load(current_ptrs, mask=mask, other=other_val)

        acc += vals.to(acc_dtype)

    final_sum = tl.sum(acc, axis=0)
    tl.store(out_ptr, final_sum.to(out_ptr.type.element_ty))


def trace(self):
    logger.debug("GEMS TRACE")

    if self.ndim != 2:
        raise RuntimeError(
            f"trace: expected a 2D tensor, but got a {self.ndim}D tensor"
        )

    M, N = self.shape
    stride0, stride1 = self.stride()
    num_diag = min(M, N)
    if num_diag == 0:
        if self.dtype.is_floating_point:
            return torch.tensor(0.0, dtype=self.dtype, device=self.device)
        else:
            return torch.tensor(0, dtype=torch.int64, device=self.device)

    if self.dtype.is_floating_point:
        output_dtype = self.dtype
    else:
        output_dtype = torch.int64
    out = torch.empty((), dtype=output_dtype, device=self.device)

    grid = (1,)
    BLOCK_SIZE = 1024
    if num_diag < BLOCK_SIZE:
        BLOCK_SIZE = triton.next_power_of_2(num_diag)
        if BLOCK_SIZE == 0:
            BLOCK_SIZE = 1
Review comment on lines +82 to +83: This check for …
    with torch_device_fn.device(self.device):
        trace_kernel[grid](
            self,
            out,
            num_diag,
            stride0,
            stride1,
            BLOCK_SIZE=BLOCK_SIZE,
        )

    return out
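For intuition, the addressing scheme above treats the main diagonal as a strided 1D vector: element i of the diagonal sits at offset i * (stride0 + stride1) from the data pointer. A minimal sketch (not part of the PR; it assumes a contiguous input) that cross-checks this against torch.diagonal:

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
s0, s1 = x.stride()                           # (4, 1) for a contiguous (3, 4) tensor
flat = x.reshape(-1)                          # view of the underlying storage
idx = torch.arange(min(x.shape)) * (s0 + s1)  # offsets of the diagonal elements
assert torch.equal(flat[idx].sum(), torch.diagonal(x).sum())  # both equal x.trace() == 15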
@@ -846,6 +846,43 @@ def test_accuracy_inplace_scatter_mul(src_shape, inp_shape, dim, dtype):
    gems_assert_close(res_out, ref_out, dtype)


TRACE_SHAPES = [
    (1, 1),
    (5, 5),
    (10, 20),
    (30, 15),
    (1, 100),
    (100, 1),
    (128, 256),
    (256, 128),
    (0, 10),  # empty diagonal
    (10, 0),  # empty diagonal
    (1500, 1200),  # larger shape
]


@pytest.mark.trace
@pytest.mark.parametrize("shape", TRACE_SHAPES)
@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES + [torch.bool])
def test_accuracy_trace(shape, dtype):
    if dtype == torch.bool:
        inp = torch.randint(0, 2, size=shape, device=flag_gems.device).to(dtype)
    elif dtype in INT_DTYPES:
        inp = torch.randint(-100, 100, size=shape, device=flag_gems.device).to(dtype)
    else:
        inp = torch.randn(shape, dtype=dtype, device=flag_gems.device)

    ref_inp = to_reference(inp)
    ref_out = torch.trace(ref_inp)
    with flag_gems.use_gems():
        res_out = torch.trace(inp)

    if dtype in FLOAT_DTYPES:
        gems_assert_close(res_out, ref_out, dtype)
    else:
        gems_assert_equal(res_out, ref_out)
Review comment: The test coverage is good. To make it even more robust, consider adding test cases for non-contiguous tensors, such as a transposed matrix. This would ensure that the implementation correctly handles different memory layouts via strides. You could achieve this by creating a transposed tensor within the existing …
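Picking up that suggestion, here is a hedged sketch of what such a test could look like (the test name is hypothetical and not part of the PR; it reuses the surrounding file's imports and helpers such as FLOAT_DTYPES, to_reference, gems_assert_close, and flag_gems.device):

@pytest.mark.trace
@pytest.mark.parametrize("shape", [(10, 20), (128, 256)])
@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
def test_accuracy_trace_non_contiguous(shape, dtype):
    # Transposing swaps the strides without copying, so the kernel has to honor
    # stride0/stride1 instead of assuming a contiguous row-major layout.
    base = torch.randn(shape, dtype=dtype, device=flag_gems.device)
    inp = base.t()
    assert not inp.is_contiguous()

    ref_out = torch.trace(to_reference(inp))
    with flag_gems.use_gems():
        res_out = torch.trace(inp)

    gems_assert_close(res_out, ref_out, dtype)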
@pytest.mark.gather
@pytest.mark.parametrize(
    "inp_shape",
Review comment: Consider partitioning the task of summing the diagonal elements across several blocks when the number of elements to sum is large.

Review comment: You can treat it like a normal reduction over a vector whose stride is stride_dim0 + stride_dim1.
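To illustrate those two suggestions together, below is a hedged sketch of a two-stage reduction, not the PR's implementation: each program sums one chunk of the diagonal, viewed as a vector with stride stride0 + stride1, and a second stage reduces the per-block partial sums. The kernel and function names are hypothetical, floating-point input is assumed, and the PR's per-dtype accumulator handling is omitted for brevity.

import torch
import triton
import triton.language as tl


@triton.jit
def trace_partial_kernel(
    inp_ptr,
    partial_ptr,
    num_diag,
    diag_stride,  # stride0 + stride1: step between consecutive diagonal elements
    BLOCK_SIZE: tl.constexpr,
):
    # Each program reduces one BLOCK_SIZE-sized chunk of the diagonal.
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_diag
    vals = tl.load(inp_ptr + offsets * diag_stride, mask=mask, other=0.0)
    partial = tl.sum(vals.to(tl.float32), axis=0)
    tl.store(partial_ptr + pid, partial)


def trace_two_stage(x, BLOCK_SIZE=1024):
    # Stage 1: per-block partial sums; stage 2: reduce the small partial-sum buffer.
    M, N = x.shape
    num_diag = min(M, N)
    stride0, stride1 = x.stride()
    num_blocks = triton.cdiv(num_diag, BLOCK_SIZE)
    partials = torch.empty(num_blocks, dtype=torch.float32, device=x.device)
    trace_partial_kernel[(num_blocks,)](
        x, partials, num_diag, stride0 + stride1, BLOCK_SIZE=BLOCK_SIZE
    )
    return partials.sum().to(x.dtype)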