Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions benchmark/test_reduction_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,98 @@ def count_nonzero_input_fn(shape, dtype, device):
bench.run()


def max_pool2d_input_fn(shape, dtype, device):
    """Yield ``(tensor, pooling_kwargs)`` pairs for the max_pool2d benchmarks.

    A single tensor is built with ``generate_tensor_input`` and reused for
    every configuration. The 3x3 / stride 2 / padding 1 setup is always
    emitted; the remaining variants are only emitted when
    ``Config.bench_level`` is ``BenchLevel.COMPREHENSIVE``.
    """

    def pool_kwargs(kernel_size, stride, padding, dilation=1, ceil_mode=False):
        # Key order mirrors the pooling call's argument order.
        return {
            "kernel_size": kernel_size,
            "stride": stride,
            "padding": padding,
            "dilation": dilation,
            "ceil_mode": ceil_mode,
        }

    tensor = generate_tensor_input(shape, dtype, device)

    # Baseline configuration, produced at every benchmark level.
    yield tensor, pool_kwargs(3, 2, 1)

    if Config.bench_level != BenchLevel.COMPREHENSIVE:
        return

    # Non-square kernel/stride/padding, skipped when either of the last two
    # dimensions is 5 or smaller.
    rows_ok = shape[-2] > 5
    if rows_ok and shape[-1] > 5:
        yield tensor, pool_kwargs((3, 5), (2, 1), (1, 2))

    # Dilated window.
    yield tensor, pool_kwargs(3, 1, 1, dilation=2)

    # Baseline geometry with ceil_mode switched on.
    yield tensor, pool_kwargs(3, 2, 1, ceil_mode=True)


class MaxPool2dBenchmark(GenericBenchmark):
    """Benchmark that drives its ``input_fn`` over a fixed list of 4-D shapes."""

    def get_input_iter(self, cur_dtype) -> Generator:
        """Yield every case ``self.input_fn`` produces for each benchmark shape.

        ``self.input_fn`` is called as ``input_fn(shape, cur_dtype, self.device)``
        and its results are forwarded unchanged, shape by shape.
        """
        benchmark_shapes = (
            (4, 3, 224, 224),  # Typical input image size
            (16, 64, 56, 56),  # Early ResNet layer output
            (32, 128, 28, 28),  # Mid ResNet layer output
            (64, 256, 14, 14),  # Later ResNet layer output
            (128, 512, 7, 7),  # Final ResNet layer output
        )

        for dims in benchmark_shapes:
            for case in self.input_fn(dims, cur_dtype, self.device):
                yield case


@pytest.mark.max_pool2d
def test_perf_max_pool2d():
    """Benchmark flag_gems' max_pool2d_with_indices against the PyTorch op."""
    reference_op = torch.nn.functional.max_pool2d_with_indices
    gems_op = flag_gems.max_pool2d_with_indices

    benchmark = MaxPool2dBenchmark(
        op_name="max_pool2d_with_indices",
        torch_op=reference_op,
        input_fn=max_pool2d_input_fn,
        dtypes=FLOAT_DTYPES,
    )
    benchmark.set_gems(gems_op)
    benchmark.run()


@pytest.mark.max_pool2d_backward
def test_perf_max_pool2d_backward():
    """Benchmark flag_gems' max_pool2d backward against PyTorch's backward op.

    Both implementations receive the same precomputed ``grad_output``, input
    and ``indices``, so only the backward computation is timed. The PyTorch
    baseline calls ``aten::max_pool2d_with_indices_backward`` directly instead
    of re-running the forward pass and ``torch.autograd.grad`` inside the
    timed call; otherwise the baseline would be charged for forward+backward
    while the flag_gems kernel is charged for backward only.
    """

    def as_pair(value):
        # The ATen schema declares kernel_size/stride/padding/dilation as
        # int[2]; expand scalar settings so both spellings produced by
        # max_pool2d_input_fn (int or 2-tuple) are accepted.
        return (value, value) if isinstance(value, int) else tuple(value)

    def max_pool2d_backward_input_fn(shape, dtype, device):
        for inp, params in max_pool2d_input_fn(shape, dtype, device):
            # The forward pass runs once here, outside the timed region, only
            # to obtain the argmax indices and the shape of grad_output.
            output, indices = torch.nn.functional.max_pool2d_with_indices(inp, **params)
            grad_output = torch.randn_like(output)
            yield grad_output, inp, indices, params

    def torch_max_pool2d_backward_wrapper(grad_output, input, indices, **kwargs):
        # ATen argument order: grad_output, self, kernel_size, stride,
        # padding, dilation, ceil_mode, indices.
        return torch.ops.aten.max_pool2d_with_indices_backward(
            grad_output,
            input,
            as_pair(kwargs["kernel_size"]),
            as_pair(kwargs["stride"]),
            as_pair(kwargs["padding"]),
            as_pair(kwargs["dilation"]),
            kwargs["ceil_mode"],
            indices,
        )

    bench = MaxPool2dBenchmark(
        input_fn=max_pool2d_backward_input_fn,
        op_name="max_pool2d_backward",
        torch_op=torch_max_pool2d_backward_wrapper,
        dtypes=FLOAT_DTYPES,
        # The backward op is invoked explicitly, so the harness should time
        # a plain call rather than drive autograd itself.
        is_backward=False,
    )

    bench.set_gems(flag_gems.max_pool2d_backward)
    bench.run()


@pytest.mark.dot
def test_perf_dot():
def dot_input_fn(shape, dtype, device):
Expand Down
2 changes: 2 additions & 0 deletions src/flag_gems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,8 @@ def enable(
("max", max),
("max.dim", max_dim),
("maximum", maximum),
("max_pool2d_with_indices", max_pool2d_with_indices),
("max_pool2d_backward", max_pool2d_backward),
("mean", mean),
("mean.dim", mean_dim),
("min", min),
Expand Down
6 changes: 6 additions & 0 deletions src/flag_gems/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,10 @@
from flag_gems.ops.masked_fill import masked_fill, masked_fill_
from flag_gems.ops.masked_select import masked_select
from flag_gems.ops.max import max, max_dim
from flag_gems.ops.max_pool2d_with_indices import (
max_pool2d_backward,
max_pool2d_with_indices,
)
from flag_gems.ops.maximum import maximum
from flag_gems.ops.mean import mean, mean_dim
from flag_gems.ops.min import min, min_dim
Expand Down Expand Up @@ -347,6 +351,8 @@
"max",
"max_dim",
"maximum",
"max_pool2d_with_indices",
"max_pool2d_backward",
"mean",
"mean_dim",
"min",
Expand Down
Loading
Loading