Skip to content

Commit a43bdd0

Browse files
authored
[Flax] Add Flax inpainting impl (huggingface#1966)
* [Flax] Add Flax inpainting impl * fixed copies, add README.md * fixed README.md * add test * format * update README.md
1 parent f77ff56 commit a43bdd0

File tree

7 files changed

+679
-2
lines changed

7 files changed

+679
-2
lines changed

README.md

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,6 +284,53 @@ output = pipeline(
284284
output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
285285
```
286286

287+
Diffusers also has a Text-guided inpainting pipeline with Flax/Jax
288+
289+
```python
290+
import jax
291+
import numpy as np
292+
from flax.jax_utils import replicate
293+
from flax.training.common_utils import shard
294+
import PIL
295+
import requests
296+
from io import BytesIO
297+
298+
299+
from diffusers import FlaxStableDiffusionInpaintPipeline
300+
301+
def download_image(url):
302+
response = requests.get(url)
303+
return PIL.Image.open(BytesIO(response.content)).convert("RGB")
304+
img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
305+
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
306+
307+
init_image = download_image(img_url).resize((512, 512))
308+
mask_image = download_image(mask_url).resize((512, 512))
309+
310+
pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained("xvjiarui/stable-diffusion-2-inpainting")
311+
312+
prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
313+
prng_seed = jax.random.PRNGKey(0)
314+
num_inference_steps = 50
315+
316+
num_samples = jax.device_count()
317+
prompt = num_samples * [prompt]
318+
init_image = num_samples * [init_image]
319+
mask_image = num_samples * [mask_image]
320+
prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(prompt, init_image, mask_image)
321+
322+
323+
# shard inputs and rng
324+
params = replicate(params)
325+
prng_seed = jax.random.split(prng_seed, jax.device_count())
326+
prompt_ids = shard(prompt_ids)
327+
processed_masked_images = shard(processed_masked_images)
328+
processed_masks = shard(processed_masks)
329+
330+
images = pipeline(prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True).images
331+
images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
332+
```
333+
287334
### Image-to-Image text-guided generation with Stable Diffusion
288335

289336
The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.

src/diffusers/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,4 +182,8 @@
182182
except OptionalDependencyNotAvailable:
183183
from .utils.dummy_flax_and_transformers_objects import * # noqa F403
184184
else:
185-
from .pipelines import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline
185+
from .pipelines import (
186+
FlaxStableDiffusionImg2ImgPipeline,
187+
FlaxStableDiffusionInpaintPipeline,
188+
FlaxStableDiffusionPipeline,
189+
)

src/diffusers/pipelines/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,8 @@
108108
except OptionalDependencyNotAvailable:
109109
from ..utils.dummy_flax_and_transformers_objects import * # noqa F403
110110
else:
111-
from .stable_diffusion import FlaxStableDiffusionImg2ImgPipeline, FlaxStableDiffusionPipeline
111+
from .stable_diffusion import (
112+
FlaxStableDiffusionImg2ImgPipeline,
113+
FlaxStableDiffusionInpaintPipeline,
114+
FlaxStableDiffusionPipeline,
115+
)

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,4 +99,5 @@ class FlaxStableDiffusionPipelineOutput(BaseOutput):
9999
from ...schedulers.scheduling_pndm_flax import PNDMSchedulerState
100100
from .pipeline_flax_stable_diffusion import FlaxStableDiffusionPipeline
101101
from .pipeline_flax_stable_diffusion_img2img import FlaxStableDiffusionImg2ImgPipeline
102+
from .pipeline_flax_stable_diffusion_inpaint import FlaxStableDiffusionInpaintPipeline
102103
from .safety_checker_flax import FlaxStableDiffusionSafetyChecker

0 commit comments

Comments
 (0)