
Conversation

ppadjinTT

…sors

What does this PR do?

This PR addresses the problem discussed in #12501, where using upscale_dtype = next(iter(self.up_blocks.parameters())).dtype to infer the dtype in the forward pass of the vae.decoder causes a graph break when compiling the model with torch.compile.

The issue is that next(iter(...)) forces the lazy tensors in the initial compiled model pass to materialize, resulting in a graph break, which decreases performance.

This PR proposes a simple fix by inferring the dtype as:

upscale_dtype = self.conv_out.weight.dtype
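The idea can be shown with a minimal standalone sketch (the module names mirror the decoder's but the shapes and structure here are illustrative, not the actual diffusers implementation): reading the dtype from a directly addressed parameter avoids iterating over a parameter generator inside the compiled forward pass.

```python
import torch
import torch.nn as nn

class TinyDecoder(nn.Module):
    """Toy stand-in for a VAE decoder, illustrating the dtype-inference change."""

    def __init__(self):
        super().__init__()
        self.up_blocks = nn.ModuleList([nn.Conv2d(4, 4, 3, padding=1)])
        self.conv_out = nn.Conv2d(4, 3, 3, padding=1)

    def forward(self, x):
        # Old approach: iterate over a parameter generator just to read a dtype;
        # under torch.compile this can force materialization and break the graph.
        # upscale_dtype = next(iter(self.up_blocks.parameters())).dtype

        # Proposed approach: read the dtype from a directly named parameter.
        upscale_dtype = self.conv_out.weight.dtype

        for block in self.up_blocks:
            x = block(x)
        return self.conv_out(x.to(upscale_dtype))

decoder = TinyDecoder()
out = decoder(torch.randn(1, 4, 8, 8))
print(out.shape)  # torch.Size([1, 3, 8, 8])
```

Both variants read the same information in the common case; the difference only matters for how torch.compile traces the forward pass.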

Fixes #12501

Who can review?

@sayakpaul

@sayakpaul
Member

@DN6 WDYT?

@sayakpaul sayakpaul requested a review from DN6 October 20, 2025 16:29
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ppadjinTT
Author

ppadjinTT commented Oct 23, 2025

I made sure all autoencoder tests pass locally. I would be very thankful if you could take a look, @DN6.

sample = self.conv_in(sample)

- upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+ upscale_dtype = self.up_blocks[0].resnets[0].norm1.weight.dtype
Collaborator


I think the currently failing tests in the CI are due to the fact that not every decoder block has a norm1 with a weight. Hence the use of the generator here, to avoid such cases.

@ppadjinTT I noticed you initially used self.conv_out.weight here. What was the issue you ran into with that?

Author


Okay, I will change that too, thanks! I initially changed away from self.conv_out.weight because there are some tests that check what happens when conv_out and the upscale blocks have different dtypes.

Collaborator


Could you point me to those tests? Setting it from conv_out seems more robust.

Author


Yup, these are the tests: pytest -svvv tests/models/autoencoders/test_models_autoencoder_kl.py

This is one of the tests from this set that fails: tests/models/autoencoders/test_models_autoencoder_kl.py::AutoencoderKLTests::test_layerwise_casting_inference
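A toy illustration of why a layerwise-casting test can distinguish the two approaches (module names and the manual cast here are illustrative, not the actual test's mechanism): once the up blocks are cast to a lower precision while conv_out stays in float32, the two dtype sources disagree.

```python
import torch
import torch.nn as nn

# Stand-ins for the decoder's up blocks and output convolution.
up_blocks = nn.ModuleList([nn.Conv2d(4, 4, 3, padding=1)])
conv_out = nn.Conv2d(4, 3, 3, padding=1)

# Layerwise casting may put the up blocks in lower precision
# while conv_out remains in float32.
up_blocks.to(torch.float16)

# The two ways of inferring upscale_dtype now disagree.
up_dtype = next(iter(up_blocks.parameters())).dtype
out_dtype = conv_out.weight.dtype
print(up_dtype, out_dtype)  # torch.float16 torch.float32
```

This is the kind of mismatch that makes inferring the dtype from conv_out give a different answer than inferring it from the up blocks themselves.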

Author


I added better logic for inferring the dtype, to cover the case where the simple approach doesn't work.

Collaborator


Hmm, I think we can remove upscale_dtype entirely here. I think all tests should still pass without it.

Author


Okay, let's try that. I'm pushing the change.



Development

Successfully merging this pull request may close these issues.

VAE Decoder next(iter(..)) causes graph break
