81 changes: 81 additions & 0 deletions benchmark/test_reduction_perf.py
@@ -240,6 +240,87 @@ def count_nonzero_input_fn(shape, dtype, device):
    bench.run()


def avg_pool2d_input_fn(shape, dtype, device):
    inp = generate_tensor_input(shape, dtype, device)
    # Common case
    yield inp, {
        "kernel_size": 3,
        "stride": 2,
        "padding": 1,
        "ceil_mode": False,
        "count_include_pad": True,
        "divisor_override": None,
    }
    if Config.bench_level == BenchLevel.COMPREHENSIVE:
        # With count_include_pad=False
        yield inp, {
            "kernel_size": 3,
            "stride": 2,
            "padding": 1,
            "ceil_mode": False,
            "count_include_pad": False,
            "divisor_override": None,
        }
        # With ceil_mode
        yield inp, {
            "kernel_size": 3,
            "stride": 2,
            "padding": 1,
            "ceil_mode": True,
            "count_include_pad": True,
            "divisor_override": None,
        }
        # With divisor_override
        if shape[-2] >= 2 and shape[-1] >= 2:
            yield inp, {
                "kernel_size": 2,
                "stride": 1,
                "padding": 0,
                "ceil_mode": False,
                "count_include_pad": True,
                "divisor_override": 3,
            }
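
For reference, each yielded (inp, kwargs) pair corresponds to a forward call like the one below (a minimal standalone sketch of the common case above; the input shape is illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(4, 3, 224, 224)
out = F.avg_pool2d(
    x,
    kernel_size=3,
    stride=2,
    padding=1,
    ceil_mode=False,
    count_include_pad=True,   # padded zeros count toward the averaging divisor
    divisor_override=None,    # when set, replaces the window area as the divisor
)
print(out.shape)  # torch.Size([4, 3, 112, 112])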


class AvgPool2dBenchmark(GenericBenchmark):
    def get_input_iter(self, cur_dtype) -> Generator:
        shapes_4d = [
            (4, 3, 224, 224),  # Typical input image size
            (16, 64, 56, 56),  # Early ResNet layer output
            (32, 128, 28, 28),  # Mid ResNet layer output
            (64, 256, 14, 14),  # Later ResNet layer output
            (128, 512, 7, 7),  # Final ResNet layer output
        ]

        for shape in shapes_4d:
            yield from self.input_fn(shape, cur_dtype, self.device)
Comment on lines +285 to +296
Contributor

medium

The get_input_iter method is identical in AvgPool2dBenchmark and MaxPool2dBenchmark (lines 392-403). You can refactor this to reduce code duplication. For instance, you could create a common base class for pooling benchmarks that contains this method.

Example:

class Pool2dBenchmark(GenericBenchmark):
    def get_input_iter(self, cur_dtype) -> Generator:
        shapes_4d = [
            (4, 3, 224, 224),  # Typical input image size
            (16, 64, 56, 56),  # Early ResNet layer output
            (32, 128, 28, 28),  # Mid ResNet layer output
            (64, 256, 14, 14),  # Later ResNet layer output
            (128, 512, 7, 7),  # Final ResNet layer output
        ]

        for shape in shapes_4d:
            yield from self.input_fn(shape, cur_dtype, self.device)

class AvgPool2dBenchmark(Pool2dBenchmark):
    pass

class MaxPool2dBenchmark(Pool2dBenchmark):
    pass

Then you would initialize AvgPool2dBenchmark and MaxPool2dBenchmark in your tests as before, but they would inherit the common get_input_iter.



@pytest.mark.avg_pool2d
def test_perf_avg_pool2d():
    bench = AvgPool2dBenchmark(
        input_fn=avg_pool2d_input_fn,
        op_name="avg_pool2d",
        torch_op=torch.nn.functional.avg_pool2d,
        dtypes=FLOAT_DTYPES,
    )
    bench.set_gems(flag_gems.avg_pool2d)
    bench.run()


@pytest.mark.avg_pool2d_backward
def test_perf_avg_pool2d_backward():
    bench = AvgPool2dBenchmark(
        input_fn=avg_pool2d_input_fn,
        op_name="avg_pool2d",
        torch_op=torch.nn.functional.avg_pool2d,
        dtypes=FLOAT_DTYPES,
        is_backward=True,
    )
    bench.set_gems(flag_gems.avg_pool2d)
    bench.run()
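
As a quick correctness sanity check, the registered kernel can be compared against the PyTorch reference (a hedged sketch: flag_gems.avg_pool2d is assumed to accept the same arguments as torch.nn.functional.avg_pool2d, as the benchmark wiring above implies):

import torch
import torch.nn.functional as F
import flag_gems

x = torch.randn(16, 64, 56, 56, device="cuda", dtype=torch.float16)
ref = F.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
out = flag_gems.avg_pool2d(x, kernel_size=3, stride=2, padding=1)
torch.testing.assert_close(out, ref)  # gems output should match the reference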


@pytest.mark.dot
def test_perf_dot():
    def dot_input_fn(shape, dtype, device):
1 change: 1 addition & 0 deletions src/flag_gems/__init__.py
@@ -62,6 +62,7 @@ def enable(
("arange.start_step", arange_start),
("argmax", argmax),
("argmin", argmin),
("avg_pool2d", avg_pool2d),
("bitwise_and.Scalar", bitwise_and_scalar),
("bitwise_and.Scalar_Tensor", bitwise_and_scalar_tensor),
("bitwise_and.Tensor", bitwise_and_tensor),
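Once this entry is registered, enabling flag_gems routes the corresponding call to the FlagGems implementation (a minimal usage sketch; the device string is an assumption, and the dispatch behavior follows flag_gems' standard enable() mechanism rather than anything shown in this diff):

import torch
import flag_gems

flag_gems.enable()  # installs the ("avg_pool2d", avg_pool2d) registration above
x = torch.randn(4, 3, 224, 224, device="cuda")
y = torch.nn.functional.avg_pool2d(x, kernel_size=3, stride=2, padding=1)  # now served by flag_gems
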
2 changes: 2 additions & 0 deletions src/flag_gems/ops/__init__.py
@@ -16,6 +16,7 @@
    flash_attn_varlen_func,
    scaled_dot_product_attention,
)
from flag_gems.ops.avg_pool2d import avg_pool2d
from flag_gems.ops.batch_norm import batch_norm, batch_norm_backward
from flag_gems.ops.bitwise_and import (
    bitwise_and_scalar,
@@ -226,6 +227,7 @@
"arange_start",
"argmax",
"argmin",
"avg_pool2d",
"batch_norm",
"batch_norm_backward",
"bitwise_and_scalar",