Assertion when lowering from Torch IR for AvgPool2d when kernel is a tuple of a single int #3885

sahas3 opened this issue Nov 21, 2024 · 0 comments

Comments

@sahas3
Copy link
Contributor

sahas3 commented Nov 21, 2024

ExportedProgram for

class AvgPool2dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(
            kernel_size=6,
        )

    def forward(self, x):
        return self.ap2d(x)

produces the call to AvgPool2d as torch.ops.aten.avg_pool2d.default(x, [6, 6], [6, 6]). This matches the documented behavior of the kernel_size parameter (https://pytorch.org/docs/stable/generated/torch.nn.AvgPool2d.html), which states that a single integer value will be used for both the height and width dimensions. Per the documentation, the only other accepted value for kernel_size is a tuple of two integers. However, a tuple with a single element works as well:

class AvgPool2dFloatModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ap2d = torch.nn.AvgPool2d(
            kernel_size=(6,)
        )

    def forward(self, x):
        return self.ap2d(x)

and the ExportedProgram calls AvgPool2d as torch.ops.aten.avg_pool2d.default(x, [6], [6]). Note that the kernel value is not repeated here, even though it is effectively repeated when the module is executed eagerly in Python.

This ExportedProgram causes an assertion when lowering the resulting Torch IR to TOSA/Linalg/StableHLO, because those lowerings assume the kernel list has two elements.
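To make the failure mode concrete, here is a minimal Python sketch of the assumption the lowerings make (the actual lowerings are C++ in torch-mlir; the function below is purely illustrative):

```python
def lower_avg_pool2d(kernel_size):
    """Illustrative sketch (not real torch-mlir code): the TOSA/Linalg/StableHLO
    lowerings read two kernel entries unconditionally, so a one-element list
    like [6] trips an assertion instead of being repeated to [6, 6]."""
    assert len(kernel_size) == 2, "expected a 2-element kernel"
    kh, kw = kernel_size
    return kh, kw
```

With kernel_size=[6, 6] this succeeds, but the [6] produced by the second ExportedProgram fails the length assertion.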

So I think this can be fixed by one of the following approaches:

  1. Change the ExportedProgram behavior in the second scenario to match the first one. I am not familiar with the PyTorch codebase, so I'm not sure where to make the change; if anyone knows where to start looking, I'd appreciate a pointer.
  2. Fix the individual lowerings, but that means repeating the same logic in three different places.
  3. In Torch IR, before any of the lowerings run (possibly when DecomposeComplexOps is called), extend the kernel param of the torch.aten.avg_pool2d op to the correct size, so the individual lowerings don't need to be fixed.

I'm leaning towards option 3 (since I don't know how to make option 1 work) -- is that the correct approach? If so, which pass would be the right place to add the logic? AFAICT none of the existing passes does a similar transform where an op is replaced by the same op with different params. Should I add a new pass?
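For clarity, the expansion that option 3 would perform can be sketched in Python (the real pass would live in torch-mlir's C++ rewrite infrastructure; the helper name below is hypothetical):

```python
def expand_pooling_param(param, num_spatial_dims=2):
    """Hypothetical sketch of the normalization option 3 proposes: repeat a
    one-element kernel/stride list across all spatial dimensions, matching
    PyTorch's eager-mode behavior for a single-int kernel_size. A 2-element
    list is left unchanged."""
    if len(param) == 1:
        return list(param) * num_spatial_dims
    return list(param)
```

Running this over the kernel and stride operands of torch.aten.avg_pool2d before the backend lowerings would turn the problematic [6] into [6, 6], so none of the three lowerings needs to change.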

@sjarus, @vivekkhandelwal1 -- any thoughts? Thanks!
