Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

torch.compile ae.decode #25

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open

torch.compile ae.decode #25

wants to merge 4 commits into from

Conversation

yorickvP
Copy link
Contributor

@yorickvP yorickvP commented Sep 27, 2024

It takes about 80 seconds on my machine to compile this. Makes the encoding step about 50% faster on A5000 (0.3 -> 0.2s), haven't tried H100.

@yorickvP yorickvP requested a review from daanelson September 27, 2024 18:27
@yorickvP yorickvP force-pushed the yorickvp/torch-compile-vae branch from 461db42 to 99cecf1 Compare September 27, 2024 19:01
Copy link
Collaborator

@daanelson daanelson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is great! you can push to an internal H100 model (just don't leave it running 😄) on Replicate to test perf in prod, good to have solid metrics on that before we merge

predict.py Outdated
@@ -166,12 +167,65 @@ def base_setup(
shared_models=shared_models,
)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit - since these flags are just simple little flags we set setup for dev/schnell predictor, I don't mind adding a separate compile_ae flag

# the order is important:
# torch.compile has to recompile if it makes invalid assumptions
# about the input sizes. Having higher input sizes first makes
# for fewer recompiles.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any way we can compile once with craftier use of dynamo.mark_dynamic - add a max=192 on dims 2 & 3? I assume you've tried this, curious how it breaks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried max=192, but it didn't have any effect. Setting torch.compile(dynamic=True) makes for one fewer recompile, but I should check the runtime performance of that.

@yorickvP
Copy link
Contributor Author

yorickvP commented Oct 1, 2024

Did some H100 benchmarks.

flux-schnell 1 image, VAE not compiled

  • 30ms prepare
  • 355ms denoise-single-item
  • 117ms vae-decode
  • total: 505ms

flux-schnell 4 images, VAE not compiled

  • 30 ms prepare
  • 4x 355 ms denoise-single-item
  • 3.21s vae-decode
  • total: 4.69s

flux-schnell 4 images, VAE compiled

  • 30ms prepare
  • 4x 355 ms denoise-single-item
  • 152ms vae-decode
  • total: 1.62s

The VAE speed seems reproducible, where the uncompiled VAE spends a lot of time in nchwToNhwcKernel while the compiled version manages to avoid it.

At the same time, I had a cog bug saying output streams failed to drain, crashing the pod instantly, but this seems unrelated to my PR.

@yorickvP yorickvP force-pushed the yorickvp/torch-compile-vae branch from 99cecf1 to 0039a42 Compare October 10, 2024 16:02
@jonluca
Copy link

jonluca commented Oct 17, 2024

Did you figure out what the output streams failed to drain issue was? I'm seeing that in prod with our cog deploy as well

@yorickvP
Copy link
Contributor Author

@jonluca as I understand it, it was a regression in cog and should be fixed when building with 0.9.25 and later.
It was caused by cog replacing stdout/stderr during predictions, but not during setup, causing forked processes to attempt to write to the original stdout/stderr. Should be fixed in replicate/cog#1969 but let me know if it's not!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants