Describe the bug
I'm trying to optimize my data preprocessing pipeline for the Sana model by using torch.compile on the DC-AE encoder. Following PyTorch's best practices, I attempted to compile only the encode method with fullgraph=True for better performance, but I'm encountering an error.
When I try:

dae.encode = torch.compile(dae.encode, fullgraph=True)

the code fails with NameError: name 'torch' is not defined when calling dae.encode(x).
However, compiling the entire model works:

dae = torch.compile(dae, fullgraph=True)

I'm unsure if this is expected behavior or if I'm doing something wrong. Is there a recommended way to compile just the encode method of AutoencoderDC?
I was advised to use the more targeted approach of compiling only the encode method for better performance, but it seems like the DC-AE model might have some internal structure that prevents this optimization pattern.
Any guidance on the correct way to apply torch.compile optimizations to AutoencoderDC would be greatly appreciated. Should I stick with compiling the entire model, or is there a way to make method-level compilation work?
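For reference, these are the two alternative patterns I was planning to try instead of patching the bound method. I haven't verified that either of them avoids the error, and I'm assuming here that AutoencoderDC keeps its encoder in a dae.encoder submodule (encode_fn is just my own wrapper, not a diffusers API):

import torch
from diffusers import AutoencoderDC

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers", torch_dtype=torch.bfloat16
).to(device).eval()
x = torch.randn(1, 3, 512, 512, device=device, dtype=torch.bfloat16)

# Alternative 1: compile the encoder submodule instead of the bound encode method
# (assumes the encoder lives at dae.encoder)
dae.encoder = torch.compile(dae.encoder, fullgraph=True)
out = dae.encode(x)

# Alternative 2: compile a standalone wrapper function that closes over the model
def encode_fn(inp):
    return dae.encode(inp)

compiled_encode = torch.compile(encode_fn, fullgraph=True)
out = compiled_encode(x)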
Reproduction
import torch
from diffusers import AutoencoderDC
# Load model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.1-diffusers",
    torch_dtype=torch.bfloat16
).to(device).eval()
# This fails with "name 'torch' is not defined"
dae.encode = torch.compile(dae.encode, fullgraph=True)
# Test
x = torch.randn(1, 3, 512, 512, device=device, dtype=torch.bfloat16)
out = dae.encode(x) # Error occurs here
# This works fine
dae = torch.compile(dae, fullgraph=True)
Logs
Testing torch.compile(dae.encode, fullgraph=True)
/data1/tzz/anaconda_dir/envs/Sana/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:150: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
warnings.warn(
✗ Error: name 'torch' is not defined
System Info
- 🤗 Diffusers version: 0.34.0.dev0
- Platform: Linux-5.15.0-142-generic-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.18
- PyTorch version (GPU?): 2.4.0+cu121 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.33.0
- Transformers version: 4.45.2
- Accelerate version: 1.7.0
- PEFT version: 0.15.2
- Bitsandbytes version: 0.46.0
- Safetensors version: 0.5.3
- xFormers version: 0.0.27.post2
- Accelerator: NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
NVIDIA A100-SXM4-80GB, 81920 MiB
- Using GPU in script?: yes
- Using distributed or parallel set-up in script?: no
Who can help?
No response