File tree 1 file changed +9
-1
lines changed 1 file changed +9
-1
lines changed Original file line number Diff line number Diff line change 5
5
6
6
import numpy as np
7
7
import torch
8
- from torchvision import transforms
8
+ from transformers import is_torchvision_available
9
9
10
10
from .models import UNet2DConditionModel
11
11
from .utils import (
23
23
if is_peft_available ():
24
24
from peft import set_peft_model_state_dict
25
25
26
+ if is_torchvision_available ():
27
+ from torchvision import transforms
28
+
26
29
27
30
def set_seed (seed : int ):
28
31
"""
@@ -79,6 +82,11 @@ def resolve_interpolation_mode(interpolation_type: str):
79
82
`torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
80
83
transform.
81
84
"""
85
+ if not is_torchvision_available ():
86
+ raise ImportError (
87
+ "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
88
+ )
89
+
82
90
if interpolation_type == "bilinear" :
83
91
interpolation_mode = transforms .InterpolationMode .BILINEAR
84
92
elif interpolation_type == "bicubic" :
You can’t perform that action at this time.
0 commit comments