127 changes: 88 additions & 39 deletions src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm

import math
from typing import List, Optional, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union

import numpy as np
import torch
@@ -36,27 +36,30 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
methods the library implements for all schedulers such as loading and saving.

Args:
sigma_min (`float`, *optional*, defaults to 0.3):
sigma_min (`float`, defaults to `0.3`):
Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1].
sigma_max (`float`, *optional*, defaults to 500):
sigma_max (`float`, defaults to `500`):
Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1].
sigma_data (`float`, *optional*, defaults to 1.0):
sigma_data (`float`, defaults to `1.0`):
The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1].
sigma_schedule (`str`, *optional*, defaults to `exponential`):
Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper
(https://huggingface.co/papers/2206.00364). Other acceptable value is "exponential". The exponential
schedule was incorporated in this model: https://huggingface.co/stabilityai/cosxl.
num_train_timesteps (`int`, defaults to 1000):
sigma_schedule (`str`, defaults to `"exponential"`):
Sigma schedule to compute the `sigmas`. Must be one of `"exponential"` or `"karras"`. The exponential
schedule was incorporated in [stabilityai/cosxl](https://huggingface.co/stabilityai/cosxl). The Karras
schedule is introduced in the [EDM](https://huggingface.co/papers/2206.00364) paper.
num_train_timesteps (`int`, defaults to `1000`):
The number of diffusion steps to train the model.
solver_order (`int`, defaults to 2):
solver_order (`int`, defaults to `2`):
The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`.
prediction_type (`str`, defaults to `v_prediction`, *optional*):
Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process),
`sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen
prediction_type (`str`, defaults to `"v_prediction"`):
Prediction type of the scheduler function. Must be one of `"epsilon"` (predicts the noise of the diffusion
process), `"sample"` (directly predicts the noisy sample), or `"v_prediction"` (see section 2.4 of [Imagen
Video](https://huggingface.co/papers/2210.02303) paper).
solver_type (`str`, defaults to `midpoint`):
Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the
sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers.
rho (`float`, defaults to `7.0`):
The parameter for calculating the Karras sigma schedule from the EDM
[paper](https://huggingface.co/papers/2206.00364).
solver_type (`str`, defaults to `"midpoint"`):
Solver type for the second-order solver. Must be one of `"midpoint"` or `"heun"`. The solver type slightly
affects the sample quality, especially for a small number of steps. It is recommended to use `"midpoint"`.
lower_order_final (`bool`, defaults to `True`):
Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can
stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10.
@@ -65,8 +68,9 @@ class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
richness. This can stabilize the sampling of the SDE variant of DPMSolver for a small number of inference
steps, but sometimes may result in blurring.
final_sigmas_type (`str`, defaults to `"zero"`):
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
The final `sigma` value for the noise schedule during the sampling process. Must be one of `"zero"` or
`"sigma_min"`. If `"sigma_min"`, the final sigma is the same as the last sigma in the training schedule. If
`"zero"`, the final sigma is set to 0.
"""

_compatibles = []
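The `sigma_schedule` and `rho` arguments documented above describe the Karras schedule from the EDM paper. As a point of reference, here is a minimal sketch of the standard Karras formulation; the scheduler's internal helper may differ in details such as dtype handling or appending a final zero sigma:

```python
import torch

# Sketch of the Karras sigma schedule from the EDM paper
# (https://huggingface.co/papers/2206.00364), using the defaults
# documented above. Assumed formulation, not the exact library helper.
def karras_sigmas(
    num_steps: int, sigma_min: float = 0.3, sigma_max: float = 500.0, rho: float = 7.0
) -> torch.Tensor:
    ramp = torch.linspace(0, 1, num_steps)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    # Interpolate in sigma^(1/rho) space, then raise back to the rho power,
    # giving a schedule that decays from sigma_max to sigma_min.
    return (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
```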
@@ -78,16 +82,16 @@ def __init__(
sigma_min: float = 0.3,
sigma_max: float = 500,
sigma_data: float = 1.0,
sigma_schedule: str = "exponential",
sigma_schedule: Literal["exponential", "karras"] = "exponential",
num_train_timesteps: int = 1000,
solver_order: int = 2,
prediction_type: str = "v_prediction",
prediction_type: Literal["epsilon", "sample", "v_prediction"] = "v_prediction",
rho: float = 7.0,
solver_type: str = "midpoint",
solver_type: Literal["midpoint", "heun"] = "midpoint",
lower_order_final: bool = True,
euler_at_final: bool = False,
final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min"
):
final_sigmas_type: Literal["zero", "sigma_min"] = "zero",
) -> None:
if solver_type not in ["midpoint", "heun"]:
if solver_type in ["logrho", "bh1", "bh2"]:
self.register_to_config(solver_type="midpoint")
@@ -113,26 +117,40 @@ def __init__(
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

@property
def init_noise_sigma(self):
# standard deviation of the initial noise distribution
def init_noise_sigma(self) -> float:
"""
The standard deviation of the initial noise distribution.

Returns:
`float`:
The initial noise sigma value computed as `sqrt(sigma_max^2 + 1)`.
"""
return (self.config.sigma_max**2 + 1) ** 0.5
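As a usage note, initial latents for sampling are typically drawn at the maximal noise level and scaled by this property. A short sketch, assuming the class is importable from the top-level `diffusers` namespace and using an illustrative latent shape:

```python
import torch

from diffusers import CosineDPMSolverMultistepScheduler

# Initial noise is scaled by init_noise_sigma, i.e. sqrt(sigma_max^2 + 1).
# The latent shape (1, 64, 1024) is an illustrative assumption.
scheduler = CosineDPMSolverMultistepScheduler()
sample = torch.randn(1, 64, 1024) * scheduler.init_noise_sigma
```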

@property
def step_index(self):
def step_index(self) -> Optional[int]:
"""
The index counter for the current timestep. It will increase by 1 after each scheduler step.

Returns:
`int` or `None`:
The current step index, or `None` if not yet initialized.
"""
return self._step_index

@property
def begin_index(self):
def begin_index(self) -> Optional[int]:
"""
The index for the first timestep. It should be set from the pipeline with the `set_begin_index` method.

Returns:
`int` or `None`:
The begin index, or `None` if not yet set.
"""
return self._begin_index

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
def set_begin_index(self, begin_index: int = 0):
def set_begin_index(self, begin_index: int = 0) -> None:
"""
Sets the begin index for the scheduler. This function should be run from the pipeline before inference.

@@ -161,7 +179,18 @@ def precondition_inputs(self, sample: torch.Tensor, sigma: Union[float, torch.Te
scaled_sample = sample * c_in
return scaled_sample

def precondition_noise(self, sigma):
def precondition_noise(self, sigma: Union[float, torch.Tensor]) -> torch.Tensor:
"""
Precondition the noise level by computing a normalized timestep representation.

Args:
sigma (`float` or `torch.Tensor`):
The sigma (noise level) value to precondition.

Returns:
`torch.Tensor`:
The preconditioned noise value computed as `atan(sigma) / pi * 2`.
"""
if not isinstance(sigma, torch.Tensor):
sigma = torch.tensor([sigma])
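A quick numeric check of the documented mapping `atan(sigma) / pi * 2`, using the Stable Audio Open endpoints `sigma_min=0.3` and `sigma_max=500` (illustrative, not library code):

```python
import math

import torch

# Small sigmas map near 0 and very large sigmas approach 1 under
# t = atan(sigma) / pi * 2, per the docstring above.
sigma = torch.tensor([0.3, 500.0])
t = torch.atan(sigma) / math.pi * 2
print(t)  # approximately tensor([0.1856, 0.9987])
```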

@@ -228,12 +257,14 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
self.is_scale_input_called = True
return sample

def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None):
def set_timesteps(
self, num_inference_steps: Optional[int] = None, device: Optional[Union[str, torch.device]] = None
) -> None:
"""
Sets the discrete timesteps used for the diffusion chain (to be run before inference).

Args:
num_inference_steps (`int`):
num_inference_steps (`int`, *optional*):
The number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`, *optional*):
The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
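A short usage sketch for `set_timesteps` (assuming the class is importable from the top-level `diffusers` namespace):

```python
from diffusers import CosineDPMSolverMultistepScheduler

# Populates scheduler.timesteps and scheduler.sigmas for 50 inference steps.
scheduler = CosineDPMSolverMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=50, device="cpu")
print(scheduler.timesteps.shape)  # expected: torch.Size([50])
```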
@@ -334,7 +365,7 @@ def _compute_exponential_sigmas(
return sigmas

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
def _sigma_to_t(self, sigma, log_sigmas):
def _sigma_to_t(self, sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
"""
Convert sigma values to corresponding timestep values through interpolation.

@@ -370,7 +401,19 @@ def _sigma_to_t(self, sigma, log_sigmas):
t = t.reshape(sigma.shape)
return t
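The interpolation body is elided in this diff; the `# Copied from` comment indicates it matches `EulerDiscreteScheduler._sigma_to_t`. A standalone sketch, reproduced under that assumption:

```python
import numpy as np

# Log-space interpolation from sigma values to fractional timestep indices.
# Assumed to mirror EulerDiscreteScheduler._sigma_to_t; not verbatim from
# the elided body.
def sigma_to_t(sigma: np.ndarray, log_sigmas: np.ndarray) -> np.ndarray:
    log_sigma = np.log(np.maximum(sigma, 1e-10))
    # Distance from the query log-sigma to every training log-sigma.
    dists = log_sigma - log_sigmas[:, np.newaxis]
    # Lower edge of the bracketing interval in the training table.
    low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
    high_idx = low_idx + 1
    low, high = log_sigmas[low_idx], log_sigmas[high_idx]
    # Linear interpolation weight inside the interval, clamped to [0, 1].
    w = np.clip((low - log_sigma) / (low - high), 0, 1)
    t = (1 - w) * low_idx + w * high_idx
    return t.reshape(sigma.shape)
```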

def _sigma_to_alpha_sigma_t(self, sigma):
def _sigma_to_alpha_sigma_t(self, sigma: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Convert sigma to alpha and sigma_t values for the diffusion process.

Args:
sigma (`torch.Tensor`):
The sigma (noise level) value.

Returns:
`Tuple[torch.Tensor, torch.Tensor]`:
A tuple containing `alpha_t` (always 1 since inputs are pre-scaled) and `sigma_t` (same as input
sigma).
"""
alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1
sigma_t = sigma
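An illustrative consequence of this convention (a note on the math, not library code): with `alpha_t` fixed at 1, the DPM-Solver log-SNR term `lambda_t = log(alpha_t / sigma_t)` reduces to `-log(sigma_t)`:

```python
import torch

# With alpha_t == 1 (inputs pre-scaled before the unet), the half-log-SNR
# used by DPM-Solver simplifies to -log(sigma_t).
sigma_t = torch.tensor(0.5)
alpha_t = torch.tensor(1.0)
lambda_t = torch.log(alpha_t) - torch.log(sigma_t)  # == -log(0.5) ≈ 0.6931
```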

@@ -536,7 +579,7 @@ def index_for_timestep(
return step_index

# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index
def _init_step_index(self, timestep):
def _init_step_index(self, timestep: Union[int, torch.Tensor]) -> None:
"""
Initialize the step_index counter for the scheduler.

Expand All @@ -557,7 +600,7 @@ def step(
model_output: torch.Tensor,
timestep: Union[int, torch.Tensor],
sample: torch.Tensor,
generator=None,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[SchedulerOutput, Tuple]:
"""
@@ -567,20 +610,19 @@
Args:
model_output (`torch.Tensor`):
The direct output from the learned diffusion model.
timestep (`int`):
timestep (`int` or `torch.Tensor`):
The current discrete timestep in the diffusion chain.
sample (`torch.Tensor`):
A current instance of a sample created by the diffusion process.
generator (`torch.Generator`, *optional*):
A random number generator.
return_dict (`bool`):
return_dict (`bool`, defaults to `True`):
Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`.

Returns:
[`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`:
If `return_dict` is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.

"""
if self.num_inference_steps is None:
raise ValueError(
@@ -702,5 +744,12 @@ def _get_conditioning_c_in(self, sigma: Union[float, torch.Tensor]) -> Union[flo
c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5)
return c_in

def __len__(self):
def __len__(self) -> int:
"""
Returns the number of training timesteps.

Returns:
`int`:
The number of training timesteps configured for the scheduler.
"""
return self.config.num_train_timesteps
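Finally, an end-to-end usage sketch tying the pieces above together. `denoiser` is a placeholder for a model whose output matches the configured `prediction_type`, and the latent shape is an illustrative assumption:

```python
import torch

from diffusers import CosineDPMSolverMultistepScheduler

# Sampling loop sketch: scale initial noise, then alternate
# scale_model_input / step over the discrete timesteps.
scheduler = CosineDPMSolverMultistepScheduler()
scheduler.set_timesteps(num_inference_steps=50)

sample = torch.randn(1, 64, 1024) * scheduler.init_noise_sigma
for t in scheduler.timesteps:
    model_input = scheduler.scale_model_input(sample, t)
    model_output = denoiser(model_input, t)  # placeholder denoiser call
    sample = scheduler.step(model_output, t, sample).prev_sample
```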