diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py
index f28688381..4dfe772b5 100644
--- a/benchmark/test_reduction_perf.py
+++ b/benchmark/test_reduction_perf.py
@@ -240,6 +240,103 @@ def count_nonzero_input_fn(shape, dtype, device):
     bench.run()
 
 
+def avg_pool2d_input_fn(shape, dtype, device):
+    inp = generate_tensor_input(shape, dtype, device)
+    # Common case
+    yield inp, {
+        "kernel_size": 3,
+        "stride": 2,
+        "padding": 1,
+        "ceil_mode": False,
+        "count_include_pad": True,
+        "divisor_override": None,
+    }
+    if Config.bench_level == BenchLevel.COMPREHENSIVE:
+        # With count_include_pad=False
+        yield inp, {
+            "kernel_size": 3,
+            "stride": 2,
+            "padding": 1,
+            "ceil_mode": False,
+            "count_include_pad": False,
+            "divisor_override": None,
+        }
+        # With ceil_mode
+        yield inp, {
+            "kernel_size": 3,
+            "stride": 2,
+            "padding": 1,
+            "ceil_mode": True,
+            "count_include_pad": True,
+            "divisor_override": None,
+        }
+        # With divisor_override
+        if shape[-2] >= 2 and shape[-1] >= 2:
+            yield inp, {
+                "kernel_size": 2,
+                "stride": 1,
+                "padding": 0,
+                "ceil_mode": False,
+                "count_include_pad": True,
+                "divisor_override": 3,
+            }
+
+
+class AvgPool2dBenchmark(GenericBenchmark):
+    def get_input_iter(self, cur_dtype) -> Generator:
+        shapes_4d = [
+            (4, 3, 224, 224),  # Typical input image size
+            (16, 64, 56, 56),  # Early ResNet layer output
+            (32, 128, 28, 28),  # Mid ResNet layer output
+            (64, 256, 14, 14),  # Later ResNet layer output
+            (128, 512, 7, 7),  # Final ResNet layer output
+        ]
+
+        for shape in shapes_4d:
+            yield from self.input_fn(shape, cur_dtype, self.device)
+
+
+@pytest.mark.avg_pool2d
+def test_perf_avg_pool2d():
+    bench = AvgPool2dBenchmark(
+        input_fn=avg_pool2d_input_fn,
+        op_name="avg_pool2d",
+        torch_op=torch.nn.functional.avg_pool2d,
+        dtypes=FLOAT_DTYPES,
+    )
+    bench.set_gems(flag_gems.avg_pool2d)
+    bench.run()
+
+
+@pytest.mark.avg_pool2d_backward
+def test_perf_avg_pool2d_backward():
+    def avg_pool2d_backward_input_fn(shape, dtype, device):
+        for forward_args in avg_pool2d_input_fn(shape, dtype, device):
+            inp, params = forward_args
+            output = torch.nn.functional.avg_pool2d(inp, **params)
+            grad_output = torch.randn_like(output)
+            inp.requires_grad_(True)
+            yield grad_output, inp, params
+
+    def torch_avg_pool2d_backward_wrapper(grad_output, input, **kwargs):
+        output = torch.nn.functional.avg_pool2d(input, **kwargs)
+        grad_input = torch.autograd.grad(
+            outputs=(output,), inputs=(input,), grad_outputs=(grad_output,)
+        )
+        return grad_input[0]
+
+    bench = AvgPool2dBenchmark(
+        input_fn=avg_pool2d_backward_input_fn,
+        op_name="avg_pool2d_backward",
+        torch_op=torch_avg_pool2d_backward_wrapper,
+        dtypes=FLOAT_DTYPES,
+        # The wrapper computes the backward pass itself, so disable the
+        # harness's automatic differentiation.
+        is_backward=False,
+    )
+
+    bench.set_gems(flag_gems.avg_pool2d_backward)
+    bench.run()
+
+
 @pytest.mark.dot
 def test_perf_dot():
     def dot_input_fn(shape, dtype, device):
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index b073b062b..b6b6d5093 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -64,6 +64,8 @@ def enable(
         ("arange.start_step", arange_start),
         ("argmax", argmax),
         ("argmin", argmin),
+        ("avg_pool2d", avg_pool2d),
+        ("avg_pool2d_backward", avg_pool2d_backward),
         ("atan", atan),
         ("atan_", atan_),
         ("bitwise_and.Scalar", bitwise_and_scalar),
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index c6d30f213..cce3b7367 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -18,6 +18,7 @@
     flash_attn_varlen_func,
     scaled_dot_product_attention,
 )
+from flag_gems.ops.avg_pool2d import avg_pool2d, avg_pool2d_backward
 from flag_gems.ops.batch_norm import batch_norm, batch_norm_backward
 from flag_gems.ops.bitwise_and import (
     bitwise_and_scalar,
@@ -239,6 +240,8 @@
     "arange_start",
     "argmax",
     "argmin",
+    "avg_pool2d",
+    "avg_pool2d_backward",
     "atan",
     "atan_",
     "batch_norm",
diff --git a/src/flag_gems/ops/avg_pool2d.py b/src/flag_gems/ops/avg_pool2d.py
new file mode 100644
index 000000000..4d8cddc6c
--- /dev/null
+++ b/src/flag_gems/ops/avg_pool2d.py
@@ -0,0 +1,420 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems.utils import libentry
+
+logger = logging.getLogger(__name__)
+
+
+def pool2d_output_size(
+    in_size: int,
+    kernel_size: int,
+    stride: int,
+    padding: int,
+    dilation: int,
+    ceil_mode: bool = False,
+) -> int:
+    effective_kernel_size = (kernel_size - 1) * dilation + 1
+    numerator = in_size + 2 * padding - effective_kernel_size
+    if ceil_mode:
+        output_size = (numerator + stride - 1) // stride + 1
+        if (output_size - 1) * stride >= in_size + padding:
+            output_size -= 1
+    else:
+        output_size = numerator // stride + 1
+
+    return output_size
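+
+# Worked example (illustrative): for in_size=32, kernel_size=3, stride=2,
+# padding=1, dilation=1, the effective kernel is 3 and the numerator is
+# 32 + 2 - 3 = 31. Floor mode yields 31 // 2 + 1 = 16 outputs; ceil mode
+# yields (31 + 1) // 2 + 1 = 17, and the clip check (17 - 1) * 2 >= 32 + 1
+# fails, so the extra output is kept.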
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
+        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
+        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
+        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
+        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
+    ],
+    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
+)
+@triton.jit
+def avg_pool2d_forward_kernel(
+    input_ptr,
+    output_ptr,
+    # Input tensor strides
+    in_stride_n,
+    in_stride_c,
+    in_stride_h,
+    in_stride_w,
+    # Input/Output shapes
+    in_c,
+    in_h,
+    in_w,
+    out_h,
+    out_w,
+    # Pooling parameters
+    kernel_h: tl.constexpr,
+    kernel_w: tl.constexpr,
+    stride_h: tl.constexpr,
+    stride_w: tl.constexpr,
+    padding_h: tl.constexpr,
+    padding_w: tl.constexpr,
+    dilation_h: tl.constexpr,
+    dilation_w: tl.constexpr,
+    # AvgPool specific parameters
+    COUNT_INCLUDE_PAD: tl.constexpr,
+    divisor_override,
+    # Tiling meta-parameters
+    BLOCK_H: tl.constexpr,
+    BLOCK_W: tl.constexpr,
+):
+    pid_nc = tl.program_id(0)
+    pid_hw = tl.program_id(1)
+    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
+    h_block_idx = pid_hw // num_w_blocks
+    w_block_idx = pid_hw % num_w_blocks
+    n_idx = pid_nc // in_c
+    c_idx = pid_nc % in_c
+
+    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
+    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
+
+    sum_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
+    count_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
+
+    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
+
+    for kh in range(0, kernel_h):
+        for kw in range(0, kernel_w):
+            h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
+            w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
+            in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
+
+            input_offset = h_in * in_stride_h + w_in * in_stride_w
+            current_val = tl.load(
+                input_base_ptr + input_offset, mask=in_mask, other=0.0
+            )
+
+            sum_acc += tl.where(in_mask, current_val, 0.0)
+            count_acc += in_mask.to(tl.int32)
+
+    if divisor_override != 0:
+        divisor = tl.full((BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32)
+    elif COUNT_INCLUDE_PAD:
+        divisor = tl.full((BLOCK_H, BLOCK_W), kernel_h * kernel_w, dtype=tl.float32)
+    else:
+        divisor = count_acc.to(tl.float32)
+
+    output_vals = tl.where(divisor != 0, sum_acc / divisor, 0.0)
+
+    out_base_ptr = output_ptr + pid_nc * out_h * out_w
+    out_h_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
+    out_w_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
+    output_block_ptr = (
+        out_base_ptr + out_h_offsets[:, None] * out_w + out_w_offsets[None, :]
+    )
+
+    out_mask = (out_h_offsets[:, None] < out_h) & (out_w_offsets[None, :] < out_w)
+    tl.store(
+        output_block_ptr, output_vals.to(output_ptr.type.element_ty), mask=out_mask
+    )
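+
+# Divisor semantics above, worked example: with kernel_size=3, stride=2,
+# padding=1, the window for output (0, 0) covers input rows/cols -1..1, of
+# which only 4 of the 9 taps are in bounds. count_include_pad=True divides
+# the sum by 9, count_include_pad=False divides by the in-bounds count 4,
+# and a nonzero divisor_override takes precedence over both.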
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
+        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
+        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=2, num_warps=8),
+        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=2, num_warps=8),
+    ],
+    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
+)
+@triton.jit
+def avg_pool2d_backward_kernel(
+    grad_output_ptr,
+    grad_input_ptr,
+    # Input/Output shapes
+    in_c,
+    in_h,
+    in_w,
+    out_h,
+    out_w,
+    # Strides
+    in_stride_n,
+    in_stride_c,
+    in_stride_h,
+    in_stride_w,
+    out_stride_n,
+    out_stride_c,
+    out_stride_h,
+    out_stride_w,
+    # Pooling parameters
+    kernel_h: tl.constexpr,
+    kernel_w: tl.constexpr,
+    stride_h: tl.constexpr,
+    stride_w: tl.constexpr,
+    padding_h: tl.constexpr,
+    padding_w: tl.constexpr,
+    dilation_h: tl.constexpr,
+    dilation_w: tl.constexpr,
+    # AvgPool specific parameters
+    COUNT_INCLUDE_PAD: tl.constexpr,
+    divisor_override,
+    # Tiling meta-parameters
+    BLOCK_H: tl.constexpr,
+    BLOCK_W: tl.constexpr,
+):
+    pid_nc = tl.program_id(0)
+    pid_hw = tl.program_id(1)
+
+    num_w_blocks = tl.cdiv(in_w, BLOCK_W)
+
+    h_block_idx = pid_hw // num_w_blocks
+    w_block_idx = pid_hw % num_w_blocks
+    n_idx = pid_nc // in_c
+    c_idx = pid_nc % in_c
+
+    grad_input_block_ptr = grad_input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
+    grad_output_base_ptr = grad_output_ptr + n_idx * out_stride_n + c_idx * out_stride_c
+
+    h_in_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
+    w_in_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
+
+    grad_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
+
+    for kh_loop in range(kernel_h):
+        for kw_loop in range(kernel_w):
+            # Invert the forward index map: which output position reaches this
+            # input position through kernel tap (kh_loop, kw_loop)?
+            h_out_num = h_in_offsets[:, None] + padding_h - kh_loop * dilation_h
+            w_out_num = w_in_offsets[None, :] + padding_w - kw_loop * dilation_w
+
+            h_valid_map = (h_out_num >= 0) & ((h_out_num % stride_h) == 0)
+            w_valid_map = (w_out_num >= 0) & ((w_out_num % stride_w) == 0)
+
+            h_out = h_out_num // stride_h
+            w_out = w_out_num // stride_w
+
+            h_out_mask = h_valid_map & (h_out < out_h)
+            w_out_mask = w_valid_map & (w_out < out_w)
+            out_mask = h_out_mask & w_out_mask
+
+            if divisor_override != 0:
+                divisor = tl.full(
+                    (BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32
+                )
+            elif COUNT_INCLUDE_PAD:
+                divisor = tl.full(
+                    (BLOCK_H, BLOCK_W), kernel_h * kernel_w, dtype=tl.float32
+                )
+            else:
+                # Recompute the forward divisor: the number of in-bounds taps
+                # of the window that produced (h_out, w_out).
+                h_start = h_out * stride_h - padding_h
+                w_start = w_out * stride_w - padding_w
+                count = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
+                for kh_count in range(0, kernel_h):
+                    for kw_count in range(0, kernel_w):
+                        h_in_for_count = h_start + kh_count * dilation_h
+                        w_in_for_count = w_start + kw_count * dilation_w
+                        is_valid = (
+                            (h_in_for_count >= 0)
+                            & (h_in_for_count < in_h)
+                            & (w_in_for_count >= 0)
+                            & (w_in_for_count < in_w)
+                        )
+                        count += is_valid.to(tl.int32)
+                divisor = count.to(tl.float32)
+
+            divisor = tl.where(divisor == 0, 1.0, divisor)
+
+            grad_out_ptr = (
+                grad_output_base_ptr + h_out * out_stride_h + w_out * out_stride_w
+            )
+            grad_out_val = tl.load(grad_out_ptr, mask=out_mask, other=0.0)
+            grad_acc += tl.where(out_mask, grad_out_val / divisor, 0.0)
+
+    grad_input_store_ptr = (
+        grad_input_block_ptr
+        + h_in_offsets[:, None] * in_stride_h
+        + w_in_offsets[None, :] * in_stride_w
+    )
+    in_write_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
+    tl.store(
+        grad_input_store_ptr,
+        grad_acc.to(grad_input_ptr.type.element_ty),
+        mask=in_write_mask,
+    )
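+
+# Index inversion above, illustrated: input row h_in receives gradient from
+# output row h_out iff h_in = h_out * stride_h - padding_h + kh * dilation_h
+# for some tap kh, i.e. h_out = (h_in + padding_h - kh * dilation_h) / stride_h
+# must be a nonnegative integer below out_h. E.g. with stride_h=2, padding_h=1,
+# h_in=4: taps kh=0 and kh=2 give odd numerators 5 and 3 and are skipped,
+# while kh=1 gives h_out = 4 / 2 = 2, the only contributing output row.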
+
+
+def _parse_pool_params(kernel_size, stride, padding):
+    if isinstance(kernel_size, int):
+        kernel_h = kernel_w = kernel_size
+    else:
+        kernel_h, kernel_w = kernel_size
+
+    if stride is None or (isinstance(stride, (list, tuple)) and not stride):
+        stride_h, stride_w = kernel_h, kernel_w
+    elif isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(padding, int):
+        padding_h = padding_w = padding
+    else:
+        padding_h, padding_w = padding
+
+    if stride_h <= 0 or stride_w <= 0:
+        raise ValueError("stride must be greater than zero")
+
+    if padding_h < 0 or padding_w < 0:
+        raise ValueError("padding must be non-negative")
+
+    if padding_h > kernel_h // 2 or padding_w > kernel_w // 2:
+        raise ValueError("pad should be smaller than or equal to half of kernel size")
+
+    return kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w
+
+
+def avg_pool2d(
+    input: torch.Tensor,
+    kernel_size,
+    stride=None,
+    padding=0,
+    ceil_mode=False,
+    count_include_pad=True,
+    divisor_override=None,
+):
+    logger.debug("GEMS AVG_POOL2D FORWARD")
+
+    if divisor_override is not None and divisor_override == 0:
+        raise ValueError("divisor_override cannot be zero")
+
+    input = input.contiguous()
+
+    kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w = _parse_pool_params(
+        kernel_size, stride, padding
+    )
+    dilation_h, dilation_w = 1, 1
+
+    in_n, in_c, in_h, in_w = input.shape
+
+    out_h = pool2d_output_size(
+        in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
+    )
+    out_w = pool2d_output_size(
+        in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
+    )
+
+    output = torch.empty(
+        (in_n, in_c, out_h, out_w), device=input.device, dtype=input.dtype
+    )
+
+    if output.numel() == 0:
+        return output
+
+    grid = lambda meta: (
+        in_n * in_c,
+        triton.cdiv(out_h, meta["BLOCK_H"]) * triton.cdiv(out_w, meta["BLOCK_W"]),
+    )
+
+    avg_pool2d_forward_kernel[grid](
+        input,
+        output,
+        input.stride(0),
+        input.stride(1),
+        input.stride(2),
+        input.stride(3),
+        in_c,
+        in_h,
+        in_w,
+        out_h,
+        out_w,
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        padding_h,
+        padding_w,
+        dilation_h,
+        dilation_w,
+        COUNT_INCLUDE_PAD=count_include_pad,
+        divisor_override=divisor_override if divisor_override is not None else 0.0,
+    )
+
+    return output
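+
+# Hypothetical usage sketch: for x = torch.randn(1, 3, 32, 32, device="cuda"),
+# avg_pool2d(x, kernel_size=3, stride=2, padding=1) returns a (1, 3, 16, 16)
+# tensor, since (32 + 2 - 3) // 2 + 1 = 16; the result should match
+# torch.nn.functional.avg_pool2d with the same arguments.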
+
+
+def avg_pool2d_backward(
+    grad_output: torch.Tensor,
+    input: torch.Tensor,
+    kernel_size,
+    stride,
+    padding,
+    ceil_mode,
+    count_include_pad,
+    divisor_override,
+):
+    logger.debug("GEMS AVG_POOL2D BACKWARD")
+
+    if divisor_override is not None and divisor_override == 0:
+        raise ValueError("divisor_override cannot be zero")
+
+    grad_output = grad_output.contiguous()
+
+    kernel_h, kernel_w, stride_h, stride_w, padding_h, padding_w = _parse_pool_params(
+        kernel_size, stride, padding
+    )
+    dilation_h, dilation_w = 1, 1
+
+    in_n, in_c, in_h, in_w = input.shape
+    out_h, out_w = grad_output.shape[2], grad_output.shape[3]
+
+    grad_input = torch.zeros_like(input, dtype=torch.float32)
+
+    if grad_output.numel() == 0:
+        return grad_input.to(grad_output.dtype)
+
+    grid = lambda meta: (
+        in_n * in_c,
+        triton.cdiv(in_h, meta["BLOCK_H"]) * triton.cdiv(in_w, meta["BLOCK_W"]),
+    )
+
+    avg_pool2d_backward_kernel[grid](
+        grad_output,
+        grad_input,
+        in_c,
+        in_h,
+        in_w,
+        out_h,
+        out_w,
+        grad_input.stride(0),
+        grad_input.stride(1),
+        grad_input.stride(2),
+        grad_input.stride(3),
+        grad_output.stride(0),
+        grad_output.stride(1),
+        grad_output.stride(2),
+        grad_output.stride(3),
+        kernel_h,
+        kernel_w,
+        stride_h,
+        stride_w,
+        padding_h,
+        padding_w,
+        dilation_h,
+        dilation_w,
+        COUNT_INCLUDE_PAD=count_include_pad,
+        divisor_override=divisor_override if divisor_override is not None else 0.0,
+    )
+
+    return grad_input.to(grad_output.dtype)
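+
+
+# Hypothetical sanity check: with x = torch.randn(1, 1, 4, 4, device="cuda"),
+# y = avg_pool2d(x, kernel_size=2, stride=2) and g = torch.ones_like(y),
+# avg_pool2d_backward(g, x, 2, 2, 0, False, True, None) spreads each output
+# gradient uniformly over its non-overlapping 2x2 window, so every element
+# of grad_input equals 1 / 4.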
diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py
index a434b9e85..2c08fc4b7 100644
--- a/tests/test_reduction_ops.py
+++ b/tests/test_reduction_ops.py
@@ -1110,6 +1110,74 @@ def test_accuracy_masked_select(shape, dtype, threshold):
     gems_assert_equal(res_out, ref_out)
 
 
+AVGPOOL2D_CONFIGS = [
+    # 3x3 kernel, stride 2, padding 1
+    ((4, 3, 32, 32), 3, 2, 1, False, True, None),
+    # Test count_include_pad=False
+    ((4, 3, 32, 32), 3, 2, 1, False, False, None),
+    # Non-square kernel and stride
+    ((8, 16, 28, 28), (3, 5), (1, 2), 1, False, True, None),
+    # Test ceil_mode
+    ((2, 4, 15, 15), 3, 2, 1, True, True, None),
+    # Test divisor_override
+    ((1, 1, 7, 7), 2, 1, 0, False, True, 1),
+    # Larger case from a typical CNN
+    ((1, 64, 56, 56), 3, 2, 1, False, True, None),
+    # No padding, count_include_pad=False
+    ((2, 8, 16, 16), 2, 2, 0, False, False, None),
+    # Non-square padding
+    ((2, 8, 16, 20), 2, 2, (1, 0), False, True, None),
+]
+
+
+@pytest.mark.avg_pool2d
+@pytest.mark.parametrize(
+    "shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override",
+    AVGPOOL2D_CONFIGS,
+)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+def test_forward_avg_pool2d(
+    shape,
+    kernel_size,
+    stride,
+    padding,
+    ceil_mode,
+    count_include_pad,
+    divisor_override,
+    dtype,
+):
+    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
+    ref_inp = to_reference(inp, True)
+
+    ref_out = torch.nn.functional.avg_pool2d(
+        ref_inp,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        ceil_mode=ceil_mode,
+        count_include_pad=count_include_pad,
+        divisor_override=divisor_override,
+    )
+
+    res_out = flag_gems.avg_pool2d(
+        inp,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        ceil_mode=ceil_mode,
+        count_include_pad=count_include_pad,
+        divisor_override=divisor_override,
+    )
+
+    gems_assert_close(res_out, ref_out, dtype)
+
+    out_grad = torch.randn_like(res_out, device=flag_gems.device)
+    ref_grad = to_reference(out_grad, True)
+    (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad)
+    # flag_gems.avg_pool2d is a plain function, so res_out carries no autograd
+    # graph; call the backward op directly instead of torch.autograd.grad.
+    res_in_grad = flag_gems.avg_pool2d_backward(
+        out_grad,
+        inp,
+        kernel_size,
+        stride,
+        padding,
+        ceil_mode,
+        count_include_pad,
+        divisor_override,
+    )
+    gems_assert_close(res_in_grad, ref_in_grad, dtype)
+
+
 SHAPE_CONV1D = [
     ((32, 2, 4), (17, 2, 2)),
     ((32, 15, 6), (17, 15, 2)),