Avoiding graph break by changing the way we infer dtype in vae.decoder #12512
Conversation
@DN6 WDYT?
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
I made sure all autoencoder tests pass locally. I would be very thankful if you could take a look @DN6
  sample = self.conv_in(sample)

- upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+ upscale_dtype = self.up_blocks[0].resnets[0].norm1.weight.dtype
I think the current failing tests in the CI are due to the fact that not every decoder block has a norm1 with a weight. Hence the use of the generator here, to avoid such cases.
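For context, a minimal self-contained sketch of the trade-off between the two lookups (an illustrative toy module, not the diffusers Decoder):

```python
import torch
import torch.nn as nn

# Toy stand-in for decoder up blocks (illustrative only, not the diffusers Decoder).
blocks = nn.ModuleList([nn.Sequential(nn.GroupNorm(1, 4), nn.Conv2d(4, 4, 3, padding=1))])

# Generator-based lookup: finds the first parameter regardless of block layout,
# but next(iter(...)) materializes lazy parameters when traced by torch.compile.
dtype_from_params = next(iter(blocks.parameters())).dtype

# Attribute-based lookup: traceable without a graph break, but raises AttributeError
# if the first block does not contain the expected submodule path (the CI failures above).
dtype_from_attr = blocks[0][0].weight.dtype

assert dtype_from_params == dtype_from_attr == torch.float32
```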
@ppadjinTT I noticed you initially used self.conv_out.weight here? What was the issue you ran into with that?
Okay, I will change that too, thanks! I initially changed self.conv_out.weight because there are some tests that check what happens when conv_out and the up blocks have different dtypes.
Could you point me to those tests? It seems like setting it to conv_out is more robust.
Yup, these are the tests: pytest -svvv tests/models/autoencoders/test_models_autoencoder_kl.py
This is one of the tests from that set that fails: tests/models/autoencoders/test_models_autoencoder_kl.py::AutoencoderKLTests::test_layerwise_casting_inference
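For readers without the repo handy, a rough approximation of the situation that test guards against (an assumption about what it checks, not the actual test code): different parts of the decoder can carry different dtypes, so the inferred upscale_dtype has to come from the up blocks the sample is actually fed into, not from conv_out:

```python
import torch
import torch.nn as nn

# Hypothetical setup, not the diffusers test: cast the up blocks and conv_out
# to different dtypes, as layerwise casting can do.
up_blocks = nn.ModuleList([nn.Conv2d(4, 4, 3, padding=1)]).to(torch.bfloat16)
conv_out = nn.Conv2d(4, 4, 3, padding=1)  # left in float32

# The sample is cast to upscale_dtype before entering the up blocks, so the dtype
# must be read from the up blocks themselves; reading it from conv_out would cast
# the sample to float32 while the up-block weights are bfloat16.
upscale_dtype = next(iter(up_blocks.parameters())).dtype
assert upscale_dtype == torch.bfloat16
assert conv_out.weight.dtype == torch.float32
```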
I added better logic for inferring the dtype, to handle the cases where the simple attribute lookup doesn't work.
Hmm, I think we can remove upscale_dtype entirely here. I think all tests should still pass without it.
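As a rough sketch of what that simplification might look like (the surrounding method layout is assumed from the usual diffusers Decoder structure, not taken from the merged diff):

```python
# Assumed sketch of the simplified forward path with upscale_dtype removed entirely.
def forward(self, sample):
    sample = self.conv_in(sample)
    sample = self.mid_block(sample)

    # Previously: upscale_dtype = ...; sample = sample.to(upscale_dtype)
    # With the cast dropped, each up block receives the sample in whatever
    # dtype the preceding layers produced.
    for up_block in self.up_blocks:
        sample = up_block(sample)

    sample = self.conv_norm_out(sample)
    sample = self.conv_act(sample)
    sample = self.conv_out(sample)
    return sample
```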
Okay, let's try that, I'm pushing the change.
What does this PR do?
This PR addresses the problem discussed in #12501, where using upscale_dtype = next(iter(self.up_blocks.parameters())).dtype to infer the dtype in the forward pass of vae.decoder causes a graph break when compiling the model with torch.compile. The issue is that next(iter(...)) forces the lazy tensors in the initial compiled model pass to materialize, resulting in a graph break, which decreases performance. This PR proposes a simple fix by inferring the dtype directly from a fixed submodule, as shown in the diff above.

Fixes #12501
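For reference, the change in context (the surrounding forward-pass lines are assumed from the standard diffusers Decoder layout, not copied from this diff):

```python
# Sketch of the dtype inference change inside Decoder.forward (context assumed).
sample = self.conv_in(sample)
sample = self.mid_block(sample)

# Before: iterating over the parameters forces lazy tensors to materialize under
# torch.compile, which splits the compiled graph.
# upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

# After: a plain attribute lookup keeps the dtype inference traceable.
upscale_dtype = self.up_blocks[0].resnets[0].norm1.weight.dtype

sample = sample.to(upscale_dtype)
for up_block in self.up_blocks:
    sample = up_block(sample)
```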
Who can review?
@sayakpaul