Skip to content

Commit c3369f5

Browse files
fix torchvision import (#6796)
1 parent 04cd6ad commit c3369f5

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/diffusers/training_utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
import numpy as np
77
import torch
8-
from torchvision import transforms
8+
from transformers import is_torchvision_available
99

1010
from .models import UNet2DConditionModel
1111
from .utils import (
@@ -23,6 +23,9 @@
2323
if is_peft_available():
2424
from peft import set_peft_model_state_dict
2525

26+
if is_torchvision_available():
27+
from torchvision import transforms
28+
2629

2730
def set_seed(seed: int):
2831
"""
@@ -79,6 +82,11 @@ def resolve_interpolation_mode(interpolation_type: str):
7982
`torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
8083
transform.
8184
"""
85+
if not is_torchvision_available():
86+
raise ImportError(
87+
"Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
88+
)
89+
8290
if interpolation_type == "bilinear":
8391
interpolation_mode = transforms.InterpolationMode.BILINEAR
8492
elif interpolation_type == "bicubic":

0 commit comments

Comments
 (0)