
StableDiffusionXLControlNetUnionInpaintPipeline requires a float for controlnet_conditioning_scale & control_guidance_end #11828

Open
@ewwwgiddings


Describe the bug

Summary

`StableDiffusionXLControlNetUnionInpaintPipeline` only accepts a `float` for `controlnet_conditioning_scale` and `control_guidance_end`, but both should also accept a `List[float]`.

Since PR #10723, the text-to-image Union pipeline (`pipeline_controlnet_union_sd_xl.py`) accepts a list or tuple, so each active ControlNet branch can have its own conditioning scale. The inpainting counterpart, `pipeline_controlnet_union_inpaint_sd_xl.py`, still contains the old check:

elif isinstance(self.controlnet, ControlNetUnionModel):
    if not isinstance(controlnet_conditioning_scale, float):
        raise TypeError(
            "For single controlnet: `controlnet_conditioning_scale` must be type `float`."
        )

Passing a list raises: TypeError: For single controlnet: `controlnet_conditioning_scale` must be type `float`.
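For illustration, here is a minimal sketch of a relaxed check that accepts either a single `float` or a list/tuple of floats. The helper name `check_union_conditioning_scale` is hypothetical and not part of diffusers. The length check assumes one scale per control image, which is an assumption on my part and not taken from #10723.

```python
from typing import Sequence, Union

Scale = Union[float, Sequence[float]]


def check_union_conditioning_scale(scale: Scale, num_control_images: int) -> None:
    """Hypothetical relaxed check: accept a float or a list/tuple of floats."""
    if isinstance(scale, float):
        return  # one scale shared by every active ControlNet branch
    if isinstance(scale, (list, tuple)):
        if not all(isinstance(s, float) for s in scale):
            raise TypeError("`controlnet_conditioning_scale` entries must be type `float`.")
        # Assumption: one scale per control image / control mode.
        if len(scale) != num_control_images:
            raise ValueError(
                f"`controlnet_conditioning_scale` has {len(scale)} entries but "
                f"{num_control_images} control images were given."
            )
        return
    raise TypeError("`controlnet_conditioning_scale` must be a `float` or a list/tuple of `float`.")


check_union_conditioning_scale(1.0, 1)         # accepted, as today
check_union_conditioning_scale([1.0], 1)       # accepted here, rejected by the current inpaint check
check_union_conditioning_scale([0.8, 0.5], 2)  # per-branch scales
```

Like the existing check, this still rejects a plain `int` such as `1`.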

In #10723, a reply to an earlier comment said this would be added to the inpainting pipeline as part of that PR (see the comment and its reply in #10723).

I don't see any mention of the `control_guidance_end` behavior in any PRs, so it may have been missed.
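`control_guidance_end` (together with `control_guidance_start`) could be handled the same way: broadcast a scalar to one entry per branch, validate each window, then compute per-step on/off weights for each branch. The sketch below is modeled on the `1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)` keep logic used in the diffusers SDXL ControlNet pipelines; the helper names `normalize_guidance_window` and `controlnet_keep` are hypothetical, and this is standalone code, not a patch.

```python
from typing import List, Sequence, Tuple, Union

Bound = Union[float, Sequence[float]]


def normalize_guidance_window(
    start: Bound, end: Bound, num_branches: int
) -> Tuple[List[float], List[float]]:
    """Broadcast scalar start/end values to one entry per ControlNet branch (hypothetical)."""
    starts = list(start) if isinstance(start, (list, tuple)) else [float(start)] * num_branches
    ends = list(end) if isinstance(end, (list, tuple)) else [float(end)] * num_branches
    if len(starts) != num_branches or len(ends) != num_branches:
        raise ValueError("control_guidance_start/end must have one entry per branch")
    for s, e in zip(starts, ends):
        # Each window must satisfy 0 <= start < end <= 1.
        if not 0.0 <= s < e <= 1.0:
            raise ValueError(f"invalid guidance window: start={s}, end={e}")
    return starts, ends


def controlnet_keep(starts: List[float], ends: List[float], num_steps: int) -> List[List[float]]:
    """Per-step, per-branch weights: 1.0 while step i falls inside [start, end], else 0.0."""
    keep = []
    for i in range(num_steps):
        keep.append([
            1.0 - float(i / num_steps < s or (i + 1) / num_steps > e)
            for s, e in zip(starts, ends)
        ])
    return keep


# Branch 0 stays active for the whole run; branch 1 stops halfway through.
starts, ends = normalize_guidance_window(0.0, [1.0, 0.5], num_branches=2)
weights = controlnet_keep(starts, ends, num_steps=4)
```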

Expected Behaviour

Reproduction

from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL
from diffusers.utils import load_image
import torch
import numpy as np
from PIL import Image

prompt = "A cat"
# download an image
image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo.png"
).resize((1024, 1024))
mask = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
).resize((1024, 1024))
# initialize the models and pipeline
controlnet = ControlNetUnionModel.from_pretrained(
    "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
)
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetUnionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=torch.float16,
    variant="fp16",
)
pipe.enable_model_cpu_offload()
controlnet_img = image.copy()
controlnet_img_np = np.array(controlnet_img)
mask_np = np.array(mask)
controlnet_img_np[mask_np > 0] = 0
controlnet_img = Image.fromarray(controlnet_img_np)
# generate image
image = pipe(
    prompt,
    image=image,
    mask_image=mask,
    control_image=[controlnet_img],
    control_mode=[7],
    controlnet_conditioning_scale=[1.0],
    control_guidance_end=[1.0],
).images[0]
image.save("inpaint.png")
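Until lists are accepted, one possible workaround for single-mode calls like the one above is to unwrap one-element lists into plain floats before calling the pipeline, since the title reports that both parameters currently require a `float`. The helper `unwrap_single` is hypothetical. It cannot provide per-branch values when several control modes are used, so it refuses lists longer than one.

```python
from typing import Sequence, Union


def unwrap_single(value: Union[float, Sequence[float]]) -> float:
    """Hypothetical workaround helper: turn [x] into x so the current float-only check passes."""
    if isinstance(value, (list, tuple)):
        if len(value) != 1:
            # Per-branch values need the pipeline fix itself; a scalar can't express them.
            raise ValueError("per-branch values are not supported by the inpaint pipeline's current check")
        value = value[0]
    return float(value)
```

Usage would then be `controlnet_conditioning_scale=unwrap_single([1.0])` and `control_guidance_end=unwrap_single([1.0])` in the `pipe(...)` call.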

Logs

Traceback (most recent call last):
  File "source\background_reference_cnet.py", line 146, in <module>
    BackgroundGeneration("images/inputs/", "images/outputs/", "test.jpg", "test.png", "RunDiffusion/Juggernaut-XL-v9")
  File "source\background_reference_cnet.py", line 113, in BackgroundGeneration
    base_result = pipeline(
                  ^^^^^^^^^
  File "source\.venv\Lib\site-packages\torch\utils\_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "source\.venv\Lib\site-packages\diffusers\pipelines\controlnet\pipeline_controlnet_union_inpaint_sd_xl.py", line 1372, in __call__
    self.check_inputs(
  File "source\.venv\Lib\site-packages\diffusers\pipelines\controlnet\pipeline_controlnet_union_inpaint_sd_xl.py", line 770, in check_inputs
    raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
TypeError: For single controlnet: `controlnet_conditioning_scale` must be type `float`.

System Info

  • 🤗 Diffusers version: 0.32.2
  • Platform: Windows-11-10.0.26100-SP0
  • Running on Google Colab?: No
  • Python version: 3.12.10
  • PyTorch version (GPU?): 2.5.1+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.33.1
  • Transformers version: 4.42.3
  • Accelerate version: 1.4.0
  • PEFT version: 0.9.0
  • Bitsandbytes version: not installed
  • Safetensors version: 0.5.3
  • xFormers version: not installed
  • Accelerator: NVIDIA RTX 3500 Ada Generation Laptop GPU, 12282 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@yiyixuxu @sayakpaul

Labels: bug (Something isn't working)