From 721757013f0ca618508714e71d0117d3f331af71 Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Fri, 26 Jan 2024 11:33:06 -0600 Subject: [PATCH] SDXL tweaks (#2072) --- .../pipeline_shark_stable_diffusion_utils.py | 2 +- .../src/schedulers/sd_schedulers.py | 15 ++++++++ .../web/ui/txt2img_sdxl_ui.py | 2 + apps/stable_diffusion/web/ui/utils.py | 38 +++++++++++++------ 4 files changed, 45 insertions(+), 12 deletions(-) diff --git a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py index 2635dce234..4c60c36096 100644 --- a/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py +++ b/apps/stable_diffusion/src/pipelines/pipeline_shark_stable_diffusion_utils.py @@ -609,7 +609,7 @@ def decode_latents_sdxl(self, latents, is_fp32_vae): # latents are in unet dtype here so switch if we want to use fp32 if is_fp32_vae: print("Casting latents to float32 for VAE") - latents = latents.to(torch.float32) + latents = latents.astype("float32") images = self.vae("forward", (latents,)) images = (torch.from_numpy(images) / 2 + 0.5).clamp(0, 1) images = images.cpu().permute(0, 2, 3, 1).float().numpy() diff --git a/apps/stable_diffusion/src/schedulers/sd_schedulers.py b/apps/stable_diffusion/src/schedulers/sd_schedulers.py index 913b15c9e2..526c0de3f8 100644 --- a/apps/stable_diffusion/src/schedulers/sd_schedulers.py +++ b/apps/stable_diffusion/src/schedulers/sd_schedulers.py @@ -81,6 +81,21 @@ def get_schedulers(model_id): algorithm_type="dpmsolver++", use_karras_sigmas=True, ) + schedulers[ + "DPMSolverSDEKarras++" + ] = DPMSolverMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + algorithm_type="sde-dpmsolver++", + use_karras_sigmas=True, + ) + schedulers[ + "DPMSolverSDE++" + ] = DPMSolverMultistepScheduler.from_pretrained( + model_id, + subfolder="scheduler", + algorithm_type="sde-dpmsolver++", + ) schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained( model_id, subfolder="scheduler", diff --git a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py index 2fd1cf5257..9550499401 100644 --- a/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py +++ b/apps/stable_diffusion/web/ui/txt2img_sdxl_ui.py @@ -361,6 +361,8 @@ def txt2img_sdxl_inf( "DPMSolverMultistep++", "DPMSolverMultistepKarras", "DPMSolverMultistepKarras++", + "DPMSolverSDEKarras++", + "DPMSolverSDE++", "EulerDiscrete", "LCMScheduler", ], diff --git a/apps/stable_diffusion/web/ui/utils.py b/apps/stable_diffusion/web/ui/utils.py index ea11bd635d..59ea1a0a23 100644 --- a/apps/stable_diffusion/web/ui/utils.py +++ b/apps/stable_diffusion/web/ui/utils.py @@ -340,21 +340,37 @@ def mask_editor_value_for_gallery_data(gallery_data): default_configs = { + # "stabilityai/sdxl-turbo": [ + # gr.Textbox(label="", interactive=False, value=None, visible=False), + # gr.Textbox( + # label="Prompt", + # value="masterpiece, a graceful shark leaping out of the water to catch a fish, eclipsing the sunset, epic, rays of light, silhouette", + # ), + # gr.Slider(0, 10, value=2), + # "EulerAncestralDiscrete", + # gr.Slider(0, value=0), + # 512, + # 512, + # "madebyollin/sdxl-vae-fp16-fix", + # gr.Checkbox( + # label="Auto-Generate", visible=False, interactive=True, value=False + # ), + # ], "stabilityai/sdxl-turbo": [ - gr.Textbox(label="", interactive=False, value=None, visible=False), - gr.Textbox( - label="Prompt", - value="masterpiece, a graceful shark leaping out of the water to catch a fish, eclipsing the sunset, epic, rays of light, silhouette", - ), - gr.Slider(0, 10, value=2), + gr.Textbox(label="Prompt", interactive=True, visible=True), + gr.Textbox(label="Negative Prompt", interactive=True), + 4, "EulerAncestralDiscrete", - gr.Slider(0, value=0), - 512, - 512, + 0, + gr.Slider(value=768, interactive=True), + gr.Slider(value=768, interactive=True), "madebyollin/sdxl-vae-fp16-fix", gr.Checkbox( - label="Auto-Generate", visible=False, interactive=True, value=False - ), + label="Auto-Generate", + visible=False, + interactive=False, + value=False, + ) ], "stabilityai/stable-diffusion-xl-base-1.0": [ gr.Textbox(label="Prompt", interactive=True, visible=True),