@@ -187,6 +187,7 @@ def __init__(
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     @property
@@ -196,6 +197,24 @@ def step_index(self):
         """
         return self._step_index

+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -255,6 +274,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic

         # add an index counter for schedulers that allow duplicated timesteps
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication

     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
@@ -620,11 +640,12 @@ def ind_fn(t, b, c, d):
         else:
             raise NotImplementedError("only support log-rho multistep deis now")

-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps

-        index_candidates = (self.timesteps == timestep).nonzero()
+        index_candidates = (schedule_timesteps == timestep).nonzero()

         if len(index_candidates) == 0:
             step_index = len(self.timesteps) - 1
@@ -637,7 +658,20 @@ def _init_step_index(self, timestep):
         else:
             step_index = index_candidates[0].item()

-        self._step_index = step_index
+        return step_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        """
+        Initialize the step_index counter for the scheduler.
+        """
+
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index

     def step(
         self,
@@ -736,16 +770,11 @@ def add_noise(
         schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)

-        step_indices = []
-        for timestep in timesteps:
-            index_candidates = (schedule_timesteps == timestep).nonzero()
-            if len(index_candidates) == 0:
-                step_index = len(schedule_timesteps) - 1
-            elif len(index_candidates) > 1:
-                step_index = index_candidates[1].item()
-            else:
-                step_index = index_candidates[0].item()
-            step_indices.append(step_index)
+        # begin_index is None when the scheduler is used for training
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]

         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
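Usage sketch (not part of the diff above): how a pipeline might call the new set_begin_index API on a DEISMultistepScheduler. The strength-based offset below is illustrative only, assuming an img2img-style pipeline that skips the early part of the schedule.

    from diffusers import DEISMultistepScheduler

    scheduler = DEISMultistepScheduler()
    num_inference_steps = 50
    scheduler.set_timesteps(num_inference_steps)

    # Illustrative img2img-style offset (assumption, not from the diff):
    # skip the earliest timesteps according to a denoising strength.
    strength = 0.6
    begin_index = int(num_inference_steps * (1.0 - strength))
    scheduler.set_begin_index(begin_index)

    # With begin_index set, step() initializes _step_index from begin_index
    # instead of searching self.timesteps for the first timestep, and
    # add_noise() uses begin_index for every sample in the batch.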