From 04d764820e5cc526092e1aee36b4e78182c42d54 Mon Sep 17 00:00:00 2001 From: Manuel Schmid <9307310+mashb1t@users.noreply.github.com> Date: Thu, 6 Jun 2024 13:42:26 +0200 Subject: [PATCH] fix: correctly set alphas_cumprod (#3106) --- ldm_patched/modules/model_sampling.py | 6 ++++++ modules/patch_precision.py | 2 ++ 2 files changed, 8 insertions(+) diff --git a/ldm_patched/modules/model_sampling.py b/ldm_patched/modules/model_sampling.py index bd8cb18c2..8971b4e6e 100644 --- a/ldm_patched/modules/model_sampling.py +++ b/ldm_patched/modules/model_sampling.py @@ -1,6 +1,7 @@ import torch from ldm_patched.ldm.modules.diffusionmodules.util import make_beta_schedule import math +import numpy as np class EPS: def calculate_input(self, sigma, noise): @@ -69,12 +70,17 @@ def _register_schedule(self, given_betas=None, beta_schedule="linear", timesteps # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32)) sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32) self.set_sigmas(sigmas) + self.set_alphas_cumprod(alphas_cumprod.float()) def set_sigmas(self, sigmas): self.register_buffer('sigmas', sigmas.float()) self.register_buffer('log_sigmas', sigmas.log().float()) + def set_alphas_cumprod(self, alphas_cumprod): + self.register_buffer("alphas_cumprod", alphas_cumprod.float()) + @property def sigma_min(self): return self.sigmas[0] diff --git a/modules/patch_precision.py b/modules/patch_precision.py index 83569bdd1..22ffda0ad 100644 --- a/modules/patch_precision.py +++ b/modules/patch_precision.py @@ -51,6 +51,8 @@ def patched_register_schedule(self, given_betas=None, beta_schedule="linear", ti self.linear_end = linear_end sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32) self.set_sigmas(sigmas) + alphas_cumprod = torch.tensor(alphas_cumprod, dtype=torch.float32) + self.set_alphas_cumprod(alphas_cumprod) return