diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py index f28688381..8250441f3 100644 --- a/benchmark/test_reduction_perf.py +++ b/benchmark/test_reduction_perf.py @@ -258,6 +258,22 @@ def dot_input_fn(shape, dtype, device): bench.run() +@pytest.mark.trace +def test_perf_trace(): + def trace_input_fn(shape, dtype, device): + inp = generate_tensor_input(shape, dtype=dtype, device=device) + yield inp, + + bench = GenericBenchmark2DOnly( + input_fn=trace_input_fn, + op_name="trace", + torch_op=torch.trace, + dtypes=FLOAT_DTYPES + INT_DTYPES, + ) + + bench.run() + + class quantileBenchmark(GenericBenchmark): def set_more_shapes(self): more_shapes_1d = [(4,), (1024,), (65535)] diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py index c7598b319..bb231dc77 100644 --- a/src/flag_gems/__init__.py +++ b/src/flag_gems/__init__.py @@ -292,6 +292,7 @@ def enable( ("tile", tile), ("to.dtype", to_dtype), ("topk", topk), + ("trace", trace), ("triu", triu), ("true_divide.Scalar", true_divide), ("true_divide.Tensor", true_divide), diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py index 7bc001697..e1b3e8792 100755 --- a/src/flag_gems/ops/__init__.py +++ b/src/flag_gems/ops/__init__.py @@ -180,6 +180,7 @@ from flag_gems.ops.tile import tile from flag_gems.ops.to import to_dtype from flag_gems.ops.topk import topk +from flag_gems.ops.trace import trace from flag_gems.ops.triu import triu from flag_gems.ops.uniform import uniform_ from flag_gems.ops.unique import _unique2 @@ -441,6 +442,7 @@ "tile", "to_dtype", "topk", + "trace", "triu", "true_divide", "true_divide_", diff --git a/src/flag_gems/ops/trace.py b/src/flag_gems/ops/trace.py new file mode 100644 index 000000000..9d0233db2 --- /dev/null +++ b/src/flag_gems/ops/trace.py @@ -0,0 +1,95 @@ +import logging + +import torch +import triton +import triton.language as tl + +from flag_gems.runtime import torch_device_fn +from flag_gems.utils import libentry + +logger = logging.getLogger(__name__) + + +@libentry() +@triton.jit +def trace_kernel( + inp_ptr, + out_ptr, + num_diag, + stride0, + stride1, + BLOCK_SIZE: tl.constexpr, +): + inp_dtype = inp_ptr.type.element_ty + if inp_dtype.is_int(): + acc_dtype = tl.int64 + other_val = 0 + elif inp_dtype == tl.float64: + acc_dtype = tl.float64 + other_val = 0.0 + else: + acc_dtype = tl.float32 + other_val = 0.0 + + acc = tl.zeros((BLOCK_SIZE,), dtype=acc_dtype) + + diag_stride = stride0 + stride1 + + for i in range(0, tl.cdiv(num_diag, BLOCK_SIZE)): + block_start = i * BLOCK_SIZE + current_indices = block_start + tl.arange(0, BLOCK_SIZE) + + mask = current_indices < num_diag + + ptr_offsets = current_indices * diag_stride + current_ptrs = inp_ptr + ptr_offsets + + vals = tl.load(current_ptrs, mask=mask, other=other_val) + + acc += vals.to(acc_dtype) + + final_sum = tl.sum(acc, axis=0) + tl.store(out_ptr, final_sum.to(out_ptr.type.element_ty)) + + +def trace(self): + logger.debug("GEMS TRACE") + + if self.ndim != 2: + raise RuntimeError( + f"trace: expected a 2D tensor, but got a {self.ndim}D tensor" + ) + + M, N = self.shape + stride0, stride1 = self.stride() + num_diag = min(M, N) + if num_diag == 0: + if self.dtype.is_floating_point: + return torch.tensor(0.0, dtype=self.dtype, device=self.device) + else: + return torch.tensor(0, dtype=torch.int64, device=self.device) + + if self.dtype.is_floating_point: + output_dtype = self.dtype + else: + output_dtype = torch.int64 + out = torch.empty((), dtype=output_dtype, device=self.device) + + grid = (1,) + BLOCK_SIZE = 1024 + if num_diag < BLOCK_SIZE: + BLOCK_SIZE = triton.next_power_of_2(num_diag) + if BLOCK_SIZE == 0: + BLOCK_SIZE = 1 + + with torch_device_fn.device(self.device): + trace_kernel[grid]( + self, + out, + num_diag, + stride0, + stride1, + BLOCK_SIZE=BLOCK_SIZE, + ) + + return out diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py index a7150e923..b60df9cd2 100644 --- a/tests/test_reduction_ops.py +++ b/tests/test_reduction_ops.py @@ -846,6 +846,46 @@ def test_accuracy_inplace_scatter_mul(src_shape, inp_shape, dim, dtype): gems_assert_close(res_out, ref_out, dtype) +TRACE_SHAPES = [ + (1, 1), + (5, 5), + (10, 20), + (30, 15), + (1, 100), + (100, 1), + (128, 256), + (256, 128), + (0, 10), # empty diagonal + (10, 0), # empty diagonal + (1500, 1200), # Larger shape +] + + +@pytest.mark.trace +@pytest.mark.parametrize("shape", TRACE_SHAPES) +@pytest.mark.parametrize("dtype", FLOAT_DTYPES + INT_DTYPES + [torch.bool]) +def test_accuracy_trace(shape, dtype): + if dtype == torch.bool: + inp = torch.randint(0, 2, size=shape, device=flag_gems.device).to(dtype) + elif dtype in INT_DTYPES: + inp = torch.randint(-100, 100, size=shape, device=flag_gems.device).to(dtype) + else: + inp = torch.randn(shape, dtype=dtype, device=flag_gems.device) + + ref_inp = to_reference(inp) + if dtype == torch.bool and ref_inp.device.type == "cpu": + pytest.skip("skipping bool on CPU reference.") + + ref_out = torch.trace(ref_inp) + with flag_gems.use_gems(): + res_out = torch.trace(inp) + + if dtype in FLOAT_DTYPES: + gems_assert_close(res_out, ref_out, dtype) + else: + gems_assert_equal(res_out, ref_out) + + @pytest.mark.gather @pytest.mark.parametrize( "inp_shape",