Skip to content

Commit

Permalink
Update img2imgalt.py
Browse files Browse the repository at this point in the history
WIP
  • Loading branch information
arrmansa committed Dec 30, 2024
1 parent 64a8f9d commit a63cf10
Showing 1 changed file with 22 additions and 11 deletions.
33 changes: 22 additions & 11 deletions scripts/img2imgalt.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,12 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):
t = dnw.sigma_to_t(sigma_in)

if shared.sd_model.is_sdxl:
eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["c_crossattn"][0]} )
num_classes_hack = shared.sd_model.model.diffusion_model.num_classes
shared.sd_model.model.diffusion_model.num_classes = None
try:
eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["c_crossattn"][0]} )
finally:
shared.sd_model.model.diffusion_model.num_classes = num_classes_hack
else:
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)

Expand Down Expand Up @@ -78,13 +83,6 @@ def find_noise_for_image(p, cond, uncond, cfg_scale, steps):

# Based on changes suggested by briansemrau in https://github.com/AUTOMATIC1111/stable-diffusion-webui/issues/736
def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):
if shared.sd_model.is_sdxl:
cond_tensor = cond['crossattn']
uncond_tensor = uncond['crossattn']
cond_in = torch.cat([uncond_tensor, cond_tensor])
else:
cond_in = torch.cat([uncond, cond])

x = p.init_latent

s_in = x.new_ones([x.shape[0]])
Expand Down Expand Up @@ -124,7 +122,12 @@ def find_noise_for_image_sigma_adjustment(p, cond, uncond, cfg_scale, steps):


if shared.sd_model.is_sdxl:
eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["c_crossattn"][0]} )
num_classes_hack = shared.sd_model.model.diffusion_model.num_classes
shared.sd_model.model.diffusion_model.num_classes = None
try:
eps = shared.sd_model.model(x_in * c_in, t, {"crossattn": cond_in["c_crossattn"][0]} )
finally:
shared.sd_model.model.diffusion_model.num_classes = num_classes_hack
else:
eps = shared.sd_model.apply_model(x_in * c_in, t, cond=cond_in)

Expand Down Expand Up @@ -211,9 +214,19 @@ def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subs
and self.cache.sigma_adjustment == sigma_adjustment
same_everything = same_params and self.cache.latent.shape == lat.shape and np.abs(self.cache.latent-lat).sum() < 100

rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)

if same_everything:
rec_noise = self.cache.noise
else:
# This prevents a crash, because I don't know how to access the underlying .diffusion_model yet when controlnet is enabled. WIP
# modules.sd_unet -> we're good
# scripts.hook -> we're cooked
if "scripts.hook" in str(shared.sd_model.model.diffusion_model.forward.__module__):
print("turn off any controlnets, do 1 pass and then turn controlnet back on to cache noise")
p.steps = 1
return sd_samplers.create_sampler(p.sampler_name, p.sd_model).sample_img2img(p, p.init_latent, rand_noise, conditioning, unconditional_conditioning, image_conditioning=p.image_conditioning)

shared.state.job_count += 1
cond = p.sd_model.get_learned_conditioning(p.batch_size * [original_prompt])
uncond = p.sd_model.get_learned_conditioning(p.batch_size * [original_negative_prompt])
Expand All @@ -223,8 +236,6 @@ def sample_extra(conditioning, unconditional_conditioning, seeds, subseeds, subs
rec_noise = find_noise_for_image(p, cond, uncond, cfg, st)
self.cache = Cached(rec_noise, cfg, st, lat, original_prompt, original_negative_prompt, sigma_adjustment)

rand_noise = processing.create_random_tensors(p.init_latent.shape[1:], seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w, p=p)

combined_noise = ((1 - randomness) * rec_noise + randomness * rand_noise) / ((randomness**2 + (1-randomness)**2) ** 0.5)

sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)
Expand Down

0 comments on commit a63cf10

Please sign in to comment.