From 715e7c448c5107aef2de70334c7c2fd10c2a4b7c Mon Sep 17 00:00:00 2001
From: root
Date: Sun, 14 Sep 2025 06:19:31 +0000
Subject: [PATCH 1/2] fix issue

---
 benchmark/test_reduction_perf.py |  80 ++++++
 src/flag_gems/__init__.py        |   1 +
 src/flag_gems/ops/__init__.py    |   2 +
 src/flag_gems/ops/avg_pool2d.py  | 426 +++++++++++++++++++++++++++++++
 tests/test_reduction_ops.py      |  54 ++++
 5 files changed, 563 insertions(+)
 create mode 100644 src/flag_gems/ops/avg_pool2d.py

diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py
index f28688381..8f321f569 100644
--- a/benchmark/test_reduction_perf.py
+++ b/benchmark/test_reduction_perf.py
@@ -240,6 +240,86 @@ def count_nonzero_input_fn(shape, dtype, device):
     bench.run()
 
 
+def avg_pool2d_input_fn(shape, dtype, device):
+    inp = generate_tensor_input(shape, dtype, device)
+    # Common case
+    yield inp, {
+        "kernel_size": 3,
+        "stride": 2,
+        "padding": 1,
+        "ceil_mode": False,
+        "count_include_pad": True,
+        "divisor_override": None,
+    }
+    if Config.bench_level == BenchLevel.COMPREHENSIVE:
+        # With count_include_pad=False
+        yield inp, {
+            "kernel_size": 3,
+            "stride": 2,
+            "padding": 1,
+            "ceil_mode": False,
+            "count_include_pad": False,
+            "divisor_override": None,
+        }
+        # With ceil_mode
+        yield inp, {
+            "kernel_size": 3,
+            "stride": 2,
+            "padding": 1,
+            "ceil_mode": True,
+            "count_include_pad": True,
+            "divisor_override": None,
+        }
+        # With divisor_override
+        if shape[-2] >= 2 and shape[-1] >= 2:
+            yield inp, {
+                "kernel_size": 2,
+                "stride": 1,
+                "padding": 0,
+                "ceil_mode": False,
+                "count_include_pad": True,
+                "divisor_override": 3,
+            }
+
+class AvgPool2dBenchmark(GenericBenchmark):
+    def get_input_iter(self, cur_dtype) -> Generator:
+        shapes_4d = [
+            (4, 3, 224, 224), # Typical input image size
+            (16, 64, 56, 56), # Early ResNet layer output
+            (32, 128, 28, 28), # Mid ResNet layer output
+            (64, 256, 14, 14), # Later ResNet layer output
+            (128, 512, 7, 7), # Final ResNet layer output
+        ]
+
+        for shape in shapes_4d:
+            yield from self.input_fn(shape, cur_dtype, self.device)
+
+
+@pytest.mark.avg_pool2d
+def test_perf_avg_pool2d():
+    bench = AvgPool2dBenchmark(
+        input_fn=avg_pool2d_input_fn,
+        op_name="avg_pool2d",
+        torch_op=torch.nn.functional.avg_pool2d,
+        dtypes=FLOAT_DTYPES,
+    )
+    bench.set_gems(flag_gems.avg_pool2d)
+    bench.run()
+
+
+@pytest.mark.avg_pool2d_backward
+def test_perf_avg_pool2d_backward():
+    bench = AvgPool2dBenchmark(
+        input_fn=avg_pool2d_input_fn,
+        op_name="avg_pool2d",
+        torch_op=torch.nn.functional.avg_pool2d,
+        dtypes=FLOAT_DTYPES,
+        is_backward=True,
+    )
+    bench.set_gems(flag_gems.avg_pool2d)
+    bench.run()
+
+
 @pytest.mark.dot
 def test_perf_dot():
     def dot_input_fn(shape, dtype, device):
diff --git a/src/flag_gems/__init__.py b/src/flag_gems/__init__.py
index c7598b319..fad8c9c2c 100644
--- a/src/flag_gems/__init__.py
+++ b/src/flag_gems/__init__.py
@@ -62,6 +62,7 @@ def enable(
         ("arange.start_step", arange_start),
         ("argmax", argmax),
         ("argmin", argmin),
+        ("avg_pool2d", avg_pool2d),
         ("bitwise_and.Scalar", bitwise_and_scalar),
         ("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
         ("bitwise_and.Tensor", bitwise_and_tensor),
diff --git a/src/flag_gems/ops/__init__.py b/src/flag_gems/ops/__init__.py
index 7bc001697..b46ccdb82 100755
--- a/src/flag_gems/ops/__init__.py
+++ b/src/flag_gems/ops/__init__.py
@@ -16,6 +16,7 @@
     flash_attn_varlen_func,
     scaled_dot_product_attention,
 )
+from flag_gems.ops.avg_pool2d import avg_pool2d
 from flag_gems.ops.batch_norm import batch_norm, batch_norm_backward
 from flag_gems.ops.bitwise_and import (
     bitwise_and_scalar,
@@ -226,6 +227,7 @@
     "arange_start",
     "argmax",
     "argmin",
+    "avg_pool2d",
     "batch_norm",
     "batch_norm_backward",
     "bitwise_and_scalar",
diff --git a/src/flag_gems/ops/avg_pool2d.py b/src/flag_gems/ops/avg_pool2d.py
new file mode 100644
index 000000000..64ed2adec
--- /dev/null
+++ b/src/flag_gems/ops/avg_pool2d.py
@@ -0,0 +1,426 @@
+import logging
+
+import torch
+import triton
+import triton.language as tl
+
+from flag_gems.utils import libentry
+
+logger = logging.getLogger(__name__)
+
+
+def pool2d_output_size(
+    in_size: int,
+    kernel_size: int,
+    stride: int,
+    padding: int,
+    dilation: int,
+    ceil_mode: bool = False,
+) -> int:
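+    # Standard pooling shape arithmetic: the effective kernel covers
+    # (kernel_size - 1) * dilation + 1 input positions. Worked example
+    # (illustrative): in_size=32, kernel_size=3, stride=2, padding=1,
+    # dilation=1 gives floor((32 + 2 - 3) / 2) + 1 = 16 outputs with
+    # ceil_mode=False, or ceil(31 / 2) + 1 = 17 with ceil_mode=True.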
+    effective_kernel_size = (kernel_size - 1) * dilation + 1
+    numerator = in_size + 2 * padding - effective_kernel_size
+    if ceil_mode:
+        return (numerator + stride - 1) // stride + 1
+    else:
+        return numerator // stride + 1
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
+        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
+        triton.Config({"BLOCK_H": 8, "BLOCK_W": 8}, num_stages=5, num_warps=2),
+        triton.Config({"BLOCK_H": 8, "BLOCK_W": 16}, num_stages=5, num_warps=2),
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 8}, num_stages=5, num_warps=2),
+        triton.Config({"BLOCK_H": 64, "BLOCK_W": 16}, num_stages=2, num_warps=8),
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 64}, num_stages=2, num_warps=8),
+    ],
+    key=["out_h", "out_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
+)
+@triton.jit
+def avg_pool2d_forward_kernel(
+    input_ptr,
+    output_ptr,
+    # Input tensor strides
+    in_stride_n,
+    in_stride_c,
+    in_stride_h,
+    in_stride_w,
+    # Input/Output shapes
+    in_c,
+    in_h,
+    in_w,
+    out_h,
+    out_w,
+    # Pooling parameters
+    kernel_h: tl.constexpr,
+    kernel_w: tl.constexpr,
+    stride_h: tl.constexpr,
+    stride_w: tl.constexpr,
+    padding_h: tl.constexpr,
+    padding_w: tl.constexpr,
+    dilation_h: tl.constexpr,
+    dilation_w: tl.constexpr,
+    # AvgPool specific parameters
+    COUNT_INCLUDE_PAD: tl.constexpr,
+    divisor_override: tl.constexpr,
+    # Tiling meta-parameters
+    BLOCK_H: tl.constexpr,
+    BLOCK_W: tl.constexpr,
+):
+    # Grid: axis 0 enumerates (n, c) pairs, axis 1 enumerates output tiles.
+    pid_nc = tl.program_id(0)
+    pid_hw = tl.program_id(1)
+    num_w_blocks = tl.cdiv(out_w, BLOCK_W)
+    h_block_idx = pid_hw // num_w_blocks
+    w_block_idx = pid_hw % num_w_blocks
+    n_idx = pid_nc // in_c
+    c_idx = pid_nc % in_c
+
+    h_out_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
+    w_out_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
+
+    sum_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
+    count_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
+
+    input_base_ptr = input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
+
+    for kh in range(0, kernel_h):
+        for kw in range(0, kernel_w):
+            h_in = h_out_offsets[:, None] * stride_h - padding_h + kh * dilation_h
+            w_in = w_out_offsets[None, :] * stride_w - padding_w + kw * dilation_w
+            in_mask = (h_in >= 0) & (h_in < in_h) & (w_in >= 0) & (w_in < in_w)
+
+            input_offset = h_in * in_stride_h + w_in * in_stride_w
+            current_val = tl.load(
+                input_base_ptr + input_offset, mask=in_mask, other=0.0
+            )
+
+            sum_acc += tl.where(in_mask, current_val, 0.0)
+            count_acc += in_mask.to(tl.int32)
+
+    if divisor_override > 0:
+        divisor = divisor_override
+    elif COUNT_INCLUDE_PAD:
+        divisor = kernel_h * kernel_w
+    else:
+        divisor = count_acc
+
+    output_vals = tl.where(divisor != 0, sum_acc / divisor, 0.0)
+
+    # The output tensor is freshly allocated and contiguous, so plain
+    # row-major indexing is used here instead of explicit strides.
+    out_base_ptr = output_ptr + pid_nc * out_h * out_w
+    output_block_ptr = (
+        out_base_ptr + h_out_offsets[:, None] * out_w + w_out_offsets[None, :]
+    )
+
+    out_mask = (h_out_offsets[:, None] < out_h) & (w_out_offsets[None, :] < out_w)
+    tl.store(
+        output_block_ptr, output_vals.to(output_ptr.type.element_ty), mask=out_mask
+    )
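+
+
+# Divisor selection in the forward kernel mirrors
+# torch.nn.functional.avg_pool2d: an explicit divisor_override wins;
+# otherwise count_include_pad=True divides by the full kernel_h * kernel_w
+# window, while count_include_pad=False divides by the number of in-bounds
+# elements actually summed. divisor_override=0 doubles as the "not set"
+# sentinel, since a tl.constexpr argument cannot be None.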
+
+
+@libentry()
+@triton.autotune(
+    configs=[
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 16}, num_stages=4, num_warps=4),
+        triton.Config({"BLOCK_H": 32, "BLOCK_W": 16}, num_stages=3, num_warps=4),
+        triton.Config({"BLOCK_H": 16, "BLOCK_W": 32}, num_stages=3, num_warps=4),
+        triton.Config({"BLOCK_H": 32, "BLOCK_W": 32}, num_stages=2, num_warps=8),
+        triton.Config({"BLOCK_H": 64, "BLOCK_W": 32}, num_stages=2, num_warps=8),
+        triton.Config({"BLOCK_H": 32, "BLOCK_W": 64}, num_stages=2, num_warps=8),
+    ],
+    key=["in_h", "in_w", "kernel_h", "kernel_w", "stride_h", "stride_w"],
+)
+@triton.jit
+def avg_pool2d_backward_kernel(
+    grad_output_ptr,
+    grad_input_ptr,
+    # Input/Output shapes
+    in_c,
+    in_h,
+    in_w,
+    out_h,
+    out_w,
+    # Strides
+    in_stride_n,
+    in_stride_c,
+    in_stride_h,
+    in_stride_w,
+    out_stride_n,
+    out_stride_c,
+    out_stride_h,
+    out_stride_w,
+    # Pooling parameters
+    kernel_h: tl.constexpr,
+    kernel_w: tl.constexpr,
+    stride_h: tl.constexpr,
+    stride_w: tl.constexpr,
+    padding_h: tl.constexpr,
+    padding_w: tl.constexpr,
+    dilation_h: tl.constexpr,
+    dilation_w: tl.constexpr,
+    # AvgPool specific parameters
+    COUNT_INCLUDE_PAD: tl.constexpr,
+    divisor_override: tl.constexpr,
+    # Tiling meta-parameters
+    BLOCK_H: tl.constexpr,
+    BLOCK_W: tl.constexpr,
+):
+    # Each program computes a block of grad_input.
+    pid_nc = tl.program_id(0)
+    pid_hw = tl.program_id(1)
+
+    num_w_blocks = tl.cdiv(in_w, BLOCK_W)
+
+    h_block_idx = pid_hw // num_w_blocks
+    w_block_idx = pid_hw % num_w_blocks
+    n_idx = pid_nc // in_c
+    c_idx = pid_nc % in_c
+
+    grad_input_block_ptr = grad_input_ptr + n_idx * in_stride_n + c_idx * in_stride_c
+    grad_output_base_ptr = grad_output_ptr + n_idx * out_stride_n + c_idx * out_stride_c
+
+    h_in_offsets = h_block_idx * BLOCK_H + tl.arange(0, BLOCK_H)
+    w_in_offsets = w_block_idx * BLOCK_W + tl.arange(0, BLOCK_W)
+
+    grad_acc = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.float32)
+
+    for kh_loop in range(kernel_h):
+        for kw_loop in range(kernel_w):
+            # Invert the pooling map: input index i receives gradient from
+            # output index o iff o * stride == i + padding - k * dilation.
+            h_out_num = h_in_offsets[:, None] + padding_h - kh_loop * dilation_h
+            w_out_num = w_in_offsets[None, :] + padding_w - kw_loop * dilation_w
+
+            h_valid_map = (h_out_num >= 0) & ((h_out_num % stride_h) == 0)
+            w_valid_map = (w_out_num >= 0) & ((w_out_num % stride_w) == 0)
+
+            h_out = h_out_num // stride_h
+            w_out = w_out_num // stride_w
+
+            h_out_mask = h_valid_map & (h_out < out_h)
+            w_out_mask = w_valid_map & (w_out < out_w)
+            out_mask = h_out_mask & w_out_mask
+
+            if divisor_override > 0:
+                divisor = tl.full(
+                    (BLOCK_H, BLOCK_W), divisor_override, dtype=tl.float32
+                )
+            elif COUNT_INCLUDE_PAD:
+                divisor = tl.full(
+                    (BLOCK_H, BLOCK_W), kernel_h * kernel_w, dtype=tl.float32
+                )
+            else:
+                # Re-compute the count of in-bounds elements for the divisor
+                # when padding is not included, matching the forward pass.
+                h_start = h_out * stride_h - padding_h
+                w_start = w_out * stride_w - padding_w
+                count = tl.zeros((BLOCK_H, BLOCK_W), dtype=tl.int32)
+                for kh_count in range(0, kernel_h):
+                    for kw_count in range(0, kernel_w):
+                        h_in_for_count = h_start + kh_count * dilation_h
+                        w_in_for_count = w_start + kw_count * dilation_w
+                        is_valid = (
+                            (h_in_for_count >= 0)
+                            & (h_in_for_count < in_h)
+                            & (w_in_for_count >= 0)
+                            & (w_in_for_count < in_w)
+                        )
+                        count += is_valid.to(tl.int32)
+                divisor = count.to(tl.float32)
+
+            divisor = tl.where(divisor == 0, 1.0, divisor)
+
+            grad_out_ptr = (
+                grad_output_base_ptr + h_out * out_stride_h + w_out * out_stride_w
+            )
+            grad_out_val = tl.load(grad_out_ptr, mask=out_mask, other=0.0)
+            grad_to_add = grad_out_val / divisor
+            grad_acc += tl.where(out_mask, grad_to_add, 0.0)
+
+    grad_input_store_ptr = (
+        grad_input_block_ptr
+        + h_in_offsets[:, None] * in_stride_h
+        + w_in_offsets[None, :] * in_stride_w
+    )
+    in_write_mask = (h_in_offsets[:, None] < in_h) & (w_in_offsets[None, :] < in_w)
+    tl.store(
+        grad_input_store_ptr,
+        grad_acc.to(grad_input_ptr.type.element_ty),
+        mask=in_write_mask,
+    )
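+
+
+# The backward kernel is a gather rather than a scatter: each program owns
+# one tile of grad_input and sums the contribution of every pooling window
+# that covers it, so no atomic adds are needed and each grad_input element
+# is written exactly once.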
+
+
+class AvgPool2d(torch.autograd.Function):
+    @staticmethod
+    def forward(
+        ctx,
+        input,
+        kernel_size,
+        stride,
+        padding,
+        ceil_mode,
+        count_include_pad,
+        divisor_override,
+    ):
+        logger.debug("GEMS AVG_POOL2D FORWARD")
+        input = input.contiguous()
+
+        # Normalize int-or-pair arguments; stride defaults to kernel_size.
+        if isinstance(kernel_size, int):
+            kernel_h = kernel_w = kernel_size
+        else:
+            kernel_h, kernel_w = kernel_size
+        if stride is None or (isinstance(stride, (list, tuple)) and not stride):
+            stride_h, stride_w = kernel_h, kernel_w
+        elif isinstance(stride, int):
+            stride_h = stride_w = stride
+        else:
+            stride_h, stride_w = stride
+        if isinstance(padding, int):
+            padding_h = padding_w = padding
+        else:
+            padding_h, padding_w = padding
+
+        # avg_pool2d has no dilation argument; the kernels keep it for
+        # generality, fixed at 1.
+        dilation_h, dilation_w = 1, 1
+
+        in_n, in_c, in_h, in_w = input.shape
+        out_h = pool2d_output_size(
+            in_h, kernel_h, stride_h, padding_h, dilation_h, ceil_mode
+        )
+        out_w = pool2d_output_size(
+            in_w, kernel_w, stride_w, padding_w, dilation_w, ceil_mode
+        )
+
+        output = torch.empty(
+            (in_n, in_c, out_h, out_w), device=input.device, dtype=input.dtype
+        )
+
+        if output.numel() > 0:
+            grid = lambda meta: (
+                in_n * in_c,
+                triton.cdiv(out_h, meta["BLOCK_H"])
+                * triton.cdiv(out_w, meta["BLOCK_W"]),
+            )
+
+            avg_pool2d_forward_kernel[grid](
+                input,
+                output,
+                input.stride(0),
+                input.stride(1),
+                input.stride(2),
+                input.stride(3),
+                in_c,
+                in_h,
+                in_w,
+                out_h,
+                out_w,
+                kernel_h,
+                kernel_w,
+                stride_h,
+                stride_w,
+                padding_h,
+                padding_w,
+                dilation_h,
+                dilation_w,
+                COUNT_INCLUDE_PAD=count_include_pad,
+                divisor_override=divisor_override or 0,
+            )
+
+        ctx.in_shape = input.shape
+        ctx.params = (
+            kernel_h,
+            kernel_w,
+            stride_h,
+            stride_w,
+            padding_h,
+            padding_w,
+            dilation_h,
+            dilation_w,
+            count_include_pad,
+            divisor_override,
+        )
+        return output
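+
+    # The average-pool gradient does not depend on the input values, so no
+    # tensors are saved for backward; the shape and parameters suffice.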
+    @staticmethod
+    def backward(ctx, grad_output):
+        logger.debug("GEMS AVG_POOL2D BACKWARD")
+        grad_output = grad_output.contiguous()
+        in_shape = ctx.in_shape
+        (
+            kernel_h,
+            kernel_w,
+            stride_h,
+            stride_w,
+            padding_h,
+            padding_w,
+            dilation_h,
+            dilation_w,
+            count_include_pad,
+            divisor_override,
+        ) = ctx.params
+
+        in_n, in_c, in_h, in_w = in_shape
+        out_h, out_w = grad_output.shape[2], grad_output.shape[3]
+
+        original_dtype = grad_output.dtype
+        # Allocate grad_input as zeros: the gather kernel overwrites every
+        # in-bounds element, and the zero init keeps the empty-output case
+        # (kernel never launched) correct. Accumulate in float32 and cast
+        # back to the original dtype at the end.
+        grad_input = torch.zeros(
+            in_shape, device=grad_output.device, dtype=torch.float32
+        )
+
+        if grad_output.numel() > 0:
+            grid = lambda meta: (
+                in_n * in_c,
+                triton.cdiv(in_h, meta["BLOCK_H"]) * triton.cdiv(in_w, meta["BLOCK_W"]),
+            )
+
+            avg_pool2d_backward_kernel[grid](
+                grad_output,
+                grad_input,
+                in_c,
+                in_h,
+                in_w,
+                out_h,
+                out_w,
+                grad_input.stride(0),
+                grad_input.stride(1),
+                grad_input.stride(2),
+                grad_input.stride(3),
+                grad_output.stride(0),
+                grad_output.stride(1),
+                grad_output.stride(2),
+                grad_output.stride(3),
+                kernel_h,
+                kernel_w,
+                stride_h,
+                stride_w,
+                padding_h,
+                padding_w,
+                dilation_h,
+                dilation_w,
+                COUNT_INCLUDE_PAD=count_include_pad,
+                divisor_override=divisor_override or 0,
+            )
+
+        return grad_input.to(original_dtype), None, None, None, None, None, None
+
+
+def avg_pool2d(
+    self,
+    kernel_size,
+    stride=None,
+    padding=0,
+    ceil_mode=False,
+    count_include_pad=True,
+    divisor_override=None,
+):
+    return AvgPool2d.apply(
+        self,
+        kernel_size,
+        stride,
+        padding,
+        ceil_mode,
+        count_include_pad,
+        divisor_override,
+    )
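+
+
+# Illustrative usage sketch (assumes a CUDA-like device; not part of the
+# original patch):
+#   x = torch.randn(2, 3, 32, 32, device="cuda")
+#   y = avg_pool2d(x, kernel_size=3, stride=2, padding=1)
+#   # y.shape == (2, 3, 16, 16)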
diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py
index a7150e923..107ddff1a 100644
--- a/tests/test_reduction_ops.py
+++ b/tests/test_reduction_ops.py
@@ -1108,6 +1108,60 @@ def test_accuracy_masked_select(shape, dtype, threshold):
     gems_assert_equal(res_out, ref_out)
 
 
+AVGPOOL2D_CONFIGS = [
+    # 3x3 kernel, stride 2, padding 1
+    ((4, 3, 32, 32), 3, 2, 1, False, True, None),
+    # Test count_include_pad=False
+    ((4, 3, 32, 32), 3, 2, 1, False, False, None),
+    # Non-square kernel and stride
+    ((8, 16, 28, 28), (3, 5), (1, 2), 1, False, True, None),
+    # Test ceil_mode
+    ((2, 4, 15, 15), 3, 2, 1, True, True, None),
+    # Test divisor_override
+    ((1, 1, 7, 7), 2, 1, 0, False, True, 3),
+    # Larger case from a typical CNN
+    ((1, 64, 56, 56), 3, 2, 1, False, True, None),
+    # No padding, count_include_pad=False
+    ((2, 8, 16, 16), 2, 2, 0, False, False, None),
+    # Non-square padding
+    ((2, 8, 16, 20), 2, 2, (1, 0), False, True, None),
+]
+
+
+@pytest.mark.avg_pool2d
+@pytest.mark.parametrize("shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override", AVGPOOL2D_CONFIGS)
+@pytest.mark.parametrize("dtype", FLOAT_DTYPES)
+def test_accuracy_avg_pool2d(shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, dtype):
+    inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
+    ref_inp = to_reference(inp, True)
+
+    ref_out = torch.nn.functional.avg_pool2d(
+        ref_inp,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        ceil_mode=ceil_mode,
+        count_include_pad=count_include_pad,
+        divisor_override=divisor_override,
+    )
+    res_out = flag_gems.avg_pool2d(
+        inp,
+        kernel_size=kernel_size,
+        stride=stride,
+        padding=padding,
+        ceil_mode=ceil_mode,
+        count_include_pad=count_include_pad,
+        divisor_override=divisor_override,
+    )
+    gems_assert_close(res_out, ref_out, dtype)
+
+    out_grad = torch.randn_like(res_out, device=flag_gems.device)
+    ref_grad = to_reference(out_grad, True)
+    (ref_in_grad,) = torch.autograd.grad(ref_out, ref_inp, ref_grad)
+    (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad)
+    gems_assert_close(res_in_grad, ref_in_grad, dtype)
+
+
 SHAPE_CONV1D = [
     ((32, 2, 4), (17, 2, 2)),
     ((32, 15, 6), (17, 15, 2)),

From 1ccfd134058983bc2024f31e5350a5f41a6cfad1 Mon Sep 17 00:00:00 2001
From: root
Date: Sun, 14 Sep 2025 06:27:18 +0000
Subject: [PATCH 2/2] fix issue

---
 benchmark/test_reduction_perf.py | 13 +++++++------
 tests/test_reduction_ops.py      | 18 +++++++++++++---
 2 files changed, 22 insertions(+), 9 deletions(-)

diff --git a/benchmark/test_reduction_perf.py b/benchmark/test_reduction_perf.py
index 8f321f569..f37302189 100644
--- a/benchmark/test_reduction_perf.py
+++ b/benchmark/test_reduction_perf.py
@@ -281,14 +281,15 @@ def avg_pool2d_input_fn(shape, dtype, device):
                 "divisor_override": 3,
             }
 
+
 class AvgPool2dBenchmark(GenericBenchmark):
     def get_input_iter(self, cur_dtype) -> Generator:
         shapes_4d = [
-            (4, 3, 224, 224), # Typical input image size
-            (16, 64, 56, 56), # Early ResNet layer output
-            (32, 128, 28, 28), # Mid ResNet layer output
-            (64, 256, 14, 14), # Later ResNet layer output
-            (128, 512, 7, 7), # Final ResNet layer output
+            (4, 3, 224, 224),  # Typical input image size
+            (16, 64, 56, 56),  # Early ResNet layer output
+            (32, 128, 28, 28),  # Mid ResNet layer output
+            (64, 256, 14, 14),  # Later ResNet layer output
+            (128, 512, 7, 7),  # Final ResNet layer output
         ]
 
         for shape in shapes_4d:
@@ -319,7 +320,7 @@ def test_perf_avg_pool2d_backward():
     bench.set_gems(flag_gems.avg_pool2d)
     bench.run()
 
-
+
 @pytest.mark.dot
 def test_perf_dot():
     def dot_input_fn(shape, dtype, device):
diff --git a/tests/test_reduction_ops.py b/tests/test_reduction_ops.py
index 107ddff1a..2179bfde0 100644
--- a/tests/test_reduction_ops.py
+++ b/tests/test_reduction_ops.py
@@ -1129,9 +1129,21 @@ def test_accuracy_masked_select(shape, dtype, threshold):
 
 
 @pytest.mark.avg_pool2d
-@pytest.mark.parametrize("shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override", AVGPOOL2D_CONFIGS)
+@pytest.mark.parametrize(
+    "shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override",
+    AVGPOOL2D_CONFIGS,
+)
 @pytest.mark.parametrize("dtype", FLOAT_DTYPES)
-def test_accuracy_avg_pool2d(shape, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override, dtype):
+def test_accuracy_avg_pool2d(
+    shape,
+    kernel_size,
+    stride,
+    padding,
+    ceil_mode,
+    count_include_pad,
+    divisor_override,
+    dtype,
+):
     inp = torch.randn(shape, dtype=dtype, device=flag_gems.device, requires_grad=True)
     ref_inp = to_reference(inp, True)
 
@@ -1161,7 +1173,7 @@ def test_accuracy_avg_pool2d(shape, kernel_size, stride, padding, ceil_mode, cou
     (res_in_grad,) = torch.autograd.grad(res_out, inp, out_grad)
     gems_assert_close(res_in_grad, ref_in_grad, dtype)
 
-
+
 SHAPE_CONV1D = [
     ((32, 2, 4), (17, 2, 2)),
     ((32, 15, 6), (17, 15, 2)),