
Commit adcbe67

[refactor]Scheduler.set_begin_index (#6728)
1 parent ec9840a commit adcbe67

28 files changed, +620 -279 lines

src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py

+2
@@ -789,6 +789,8 @@ def get_timesteps(self, num_inference_steps, strength, device):
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
         return timesteps, num_inference_steps - t_start
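
The same two-line addition recurs in every img2img-style `get_timesteps` below: after slicing off the first `t_start * scheduler.order` entries of the schedule for the requested strength, the pipeline tells the scheduler which index it will start from. A hedged standalone sketch of the pattern (not the exact pipeline code; `scheduler` stands for any diffusers scheduler on which `set_timesteps` has already been called):

def get_timesteps(scheduler, num_inference_steps, strength):
    # img2img keeps only the last `strength` fraction of the denoising schedule
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)

    timesteps = scheduler.timesteps[t_start * scheduler.order :]
    # new in this commit: record the starting index once, so later scheduler
    # calls do not have to search the schedule for the current timestep
    if hasattr(scheduler, "set_begin_index"):
        scheduler.set_begin_index(t_start * scheduler.order)

    return timesteps, num_inference_steps - t_start

The `hasattr` guard keeps the pipelines working with third-party or not-yet-updated schedulers that do not expose `set_begin_index`.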

src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py

+2
@@ -705,6 +705,8 @@ def get_timesteps(self, num_inference_steps, strength, device):
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
         return timesteps, num_inference_steps - t_start

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py

+2
@@ -871,6 +871,8 @@ def get_timesteps(self, num_inference_steps, strength, device):
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
         return timesteps, num_inference_steps - t_start

src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py

+2
@@ -566,6 +566,8 @@ def get_timesteps(self, num_inference_steps, strength, device):
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
         return timesteps, num_inference_steps - t_start

src/diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py

+2
@@ -536,6 +536,8 @@ def get_timesteps(self, num_inference_steps, strength, device):
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
         return timesteps, num_inference_steps - t_start

src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py

+2
@@ -634,6 +634,8 @@ def get_timesteps(self, num_inference_steps, strength, device):
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
         return timesteps, num_inference_steps - t_start

src/diffusers/pipelines/pia/pipeline_pia.py

+2
@@ -906,6 +906,8 @@ def get_timesteps(self, num_inference_steps, strength, device):
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
         return timesteps, num_inference_steps - t_start

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py

+2
@@ -467,6 +467,8 @@ def get_timesteps(self, num_inference_steps, strength, device):
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
         return timesteps, num_inference_steps - t_start

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

+2
@@ -659,6 +659,8 @@ def get_timesteps(self, num_inference_steps, strength, device):
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
         return timesteps, num_inference_steps - t_start

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

+2
@@ -859,6 +859,8 @@ def get_timesteps(self, num_inference_steps, strength, device):
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
         return timesteps, num_inference_steps - t_start

src/diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py

+2
@@ -754,6 +754,8 @@ def get_timesteps(self, num_inference_steps, strength, device):
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
         return timesteps, num_inference_steps - t_start

src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py

+2
@@ -554,6 +554,8 @@ def get_timesteps(self, num_inference_steps, strength, device):
         t_start = max(num_inference_steps - init_timestep, 0)
         timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
         return timesteps, num_inference_steps - t_start

src/diffusers/schedulers/scheduling_consistency_models.py

+41-18
@@ -98,22 +98,34 @@ def __init__(
         self.custom_timesteps = False
         self.is_scale_input_called = False
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
-    def index_for_timestep(self, timestep, schedule_timesteps=None):
-        if schedule_timesteps is None:
-            schedule_timesteps = self.timesteps
-
-        indices = (schedule_timesteps == timestep).nonzero()
-        return indices.item()
-
     @property
     def step_index(self):
         """
         The index counter for current timestep. It will increae 1 after each scheduler step.
         """
         return self._step_index
 
+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def scale_model_input(
         self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
     ) -> torch.FloatTensor:
@@ -231,6 +243,7 @@ def set_timesteps(
         self.timesteps = torch.from_numpy(timesteps).to(device=device)
 
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Modified _convert_to_karras implementation that takes in ramp as argument
@@ -280,23 +293,29 @@ def get_scalings_for_boundary_condition(self, sigma):
         c_out = (sigma - sigma_min) * sigma_data / (sigma**2 + sigma_data**2) ** 0.5
         return c_skip, c_out
 
-    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
 
-        index_candidates = (self.timesteps == timestep).nonzero()
+        indices = (schedule_timesteps == timestep).nonzero()
 
         # The sigma index that is taken for the **very** first `step`
         # is always the second index (or the last index if there is only 1)
         # This way we can ensure we don't accidentally skip a sigma in
         # case we start in the middle of the denoising schedule (e.g. for image-to-image)
-        if len(index_candidates) > 1:
-            step_index = index_candidates[1]
-        else:
-            step_index = index_candidates[0]
+        pos = 1 if len(indices) > 1 else 0
+
+        return indices[pos].item()
 
-        self._step_index = step_index.item()
+    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index
 
     def step(
         self,
@@ -412,7 +431,11 @@ def add_noise(
         schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)
 
-        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
+        # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]
 
         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
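
On the scheduler side, `begin_index` short-circuits the timestep-to-index lookup that `_init_step_index` and `add_noise` used to perform on every call. The toy below is plain PyTorch with illustrative values, not diffusers code; it mirrors the new dispatch and shows why the lookup prefers the second match when a higher-order schedule repeats timesteps:

import torch

# a second-order style schedule in which every timestep appears twice
schedule_timesteps = torch.tensor([999, 999, 749, 749, 499, 499, 249, 249])

def index_for_timestep(timestep, schedule_timesteps):
    indices = (schedule_timesteps == timestep).nonzero()
    # take the second match (if there is one) so a sigma is not skipped when
    # starting in the middle of the schedule, e.g. for image-to-image
    pos = 1 if len(indices) > 1 else 0
    return indices[pos].item()

timesteps = torch.tensor([749, 749])  # a batch of sampled timesteps
begin_index = None                    # training / older pipelines: search the schedule

if begin_index is None:
    step_indices = [index_for_timestep(t, schedule_timesteps) for t in timesteps]
else:                                 # a pipeline called set_begin_index(...)
    step_indices = [begin_index] * timesteps.shape[0]

print(step_indices)  # [3, 3] via the lookup; [2, 2] if begin_index had been set to 2

Once a pipeline sets `begin_index`, the batched equality search, and its ambiguity for duplicated timesteps, is avoided entirely for that run.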

src/diffusers/schedulers/scheduling_deis_multistep.py

+44-15
@@ -187,6 +187,7 @@ def __init__(
         self.model_outputs = [None] * solver_order
         self.lower_order_nums = 0
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     @property
@@ -196,6 +197,24 @@ def step_index(self):
         """
         return self._step_index
 
+    @property
+    def begin_index(self):
+        """
+        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
+        """
+        return self._begin_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
+    def set_begin_index(self, begin_index: int = 0):
+        """
+        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.
+
+        Args:
+            begin_index (`int`):
+                The begin index for the scheduler.
+        """
+        self._begin_index = begin_index
+
     def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         """
         Sets the discrete timesteps used for the diffusion chain (to be run before inference).
@@ -255,6 +274,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
 
         # add an index counter for schedulers that allow duplicated timesteps
         self._step_index = None
+        self._begin_index = None
         self.sigmas = self.sigmas.to("cpu")  # to avoid too much CPU/GPU communication
 
     # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
@@ -620,11 +640,12 @@ def ind_fn(t, b, c, d):
         else:
             raise NotImplementedError("only support log-rho multistep deis now")
 
-    def _init_step_index(self, timestep):
-        if isinstance(timestep, torch.Tensor):
-            timestep = timestep.to(self.timesteps.device)
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep
+    def index_for_timestep(self, timestep, schedule_timesteps=None):
+        if schedule_timesteps is None:
+            schedule_timesteps = self.timesteps
 
-        index_candidates = (self.timesteps == timestep).nonzero()
+        index_candidates = (schedule_timesteps == timestep).nonzero()
 
         if len(index_candidates) == 0:
             step_index = len(self.timesteps) - 1
@@ -637,7 +658,20 @@ def _init_step_index(self, timestep):
         else:
             step_index = index_candidates[0].item()
 
-        self._step_index = step_index
+        return step_index
+
+    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
+    def _init_step_index(self, timestep):
+        """
+        Initialize the step_index counter for the scheduler.
+        """
+
+        if self.begin_index is None:
+            if isinstance(timestep, torch.Tensor):
+                timestep = timestep.to(self.timesteps.device)
+            self._step_index = self.index_for_timestep(timestep)
+        else:
+            self._step_index = self._begin_index
 
     def step(
         self,
@@ -736,16 +770,11 @@ def add_noise(
         schedule_timesteps = self.timesteps.to(original_samples.device)
         timesteps = timesteps.to(original_samples.device)
 
-        step_indices = []
-        for timestep in timesteps:
-            index_candidates = (schedule_timesteps == timestep).nonzero()
-            if len(index_candidates) == 0:
-                step_index = len(schedule_timesteps) - 1
-            elif len(index_candidates) > 1:
-                step_index = index_candidates[1].item()
-            else:
-                step_index = index_candidates[0].item()
-            step_indices.append(step_index)
+        # begin_index is None when the scheduler is used for training
+        if self.begin_index is None:
+            step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        else:
+            step_indices = [self.begin_index] * timesteps.shape[0]
 
         sigma = sigmas[step_indices].flatten()
         while len(sigma.shape) < len(original_samples.shape):
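
`begin_index` stays `None` unless a pipeline explicitly calls `set_begin_index`, so training code that samples arbitrary timesteps and calls `add_noise` keeps the old lookup behaviour. A minimal usage sketch, assuming a diffusers build that already contains this commit (scheduler config, shapes, and the index value are placeholders):

import torch
from diffusers import DEISMultistepScheduler

scheduler = DEISMultistepScheduler()  # default config, illustrative only
scheduler.set_timesteps(50)

sample = torch.randn(2, 4, 8, 8)
noise = torch.randn_like(sample)
t = scheduler.timesteps[torch.randint(0, len(scheduler.timesteps), (2,))]

# training-style call: begin_index was never set, so add_noise falls back to
# index_for_timestep() and searches the schedule for each timestep in `t`
assert scheduler.begin_index is None
noisy = scheduler.add_noise(sample, noise, t)

# img2img-style call: once the pipeline has truncated the schedule it sets the
# begin index a single time, and add_noise uses that fixed position for the batch
scheduler.set_begin_index(10)
noisy_at_begin = scheduler.add_noise(sample, noise, t)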
