@@ -284,6 +284,53 @@ output = pipeline(
output_images = pipeline.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
```
+ Diffusers also provides a text-guided inpainting pipeline with Flax/JAX:
+
+ ```python
+ import jax
+ import numpy as np
+ from flax.jax_utils import replicate
+ from flax.training.common_utils import shard
+ import PIL
+ import requests
+ from io import BytesIO
+
+
+ from diffusers import FlaxStableDiffusionInpaintPipeline
+
+ def download_image(url):
+     response = requests.get(url)
+     return PIL.Image.open(BytesIO(response.content)).convert("RGB")
+ img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+ mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+ init_image = download_image(img_url).resize((512, 512))
+ mask_image = download_image(mask_url).resize((512, 512))
+
+ pipeline, params = FlaxStableDiffusionInpaintPipeline.from_pretrained("xvjiarui/stable-diffusion-2-inpainting")
+
+ prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
+ prng_seed = jax.random.PRNGKey(0)
+ num_inference_steps = 50
+
+ num_samples = jax.device_count()
+ prompt = num_samples * [prompt]
+ init_image = num_samples * [init_image]
+ mask_image = num_samples * [mask_image]
+ prompt_ids, processed_masked_images, processed_masks = pipeline.prepare_inputs(prompt, init_image, mask_image)
+
+
+ # shard inputs and rng
+ params = replicate(params)
+ prng_seed = jax.random.split(prng_seed, jax.device_count())
+ prompt_ids = shard(prompt_ids)
+ processed_masked_images = shard(processed_masked_images)
+ processed_masks = shard(processed_masks)
+
+ images = pipeline(prompt_ids, processed_masks, processed_masked_images, params, prng_seed, num_inference_steps, jit=True).images
+ images = pipeline.numpy_to_pil(np.asarray(images.reshape((num_samples,) + images.shape[-3:])))
+ ```
+
### Image-to-Image text-guided generation with Stable Diffusion

The `StableDiffusionImg2ImgPipeline` lets you pass a text prompt and an initial image to condition the generation of new images.
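For illustration, here is a minimal sketch of how that can look, assuming a CUDA device, the `runwayml/stable-diffusion-v1-5` checkpoint, and a recent `diffusers` release in which the initial image is passed as `image` (older releases name the argument `init_image`):

```python
import requests
from io import BytesIO

import torch
from PIL import Image

from diffusers import StableDiffusionImg2ImgPipeline

# load the img2img pipeline on a GPU in half precision
# (checkpoint name is an assumption; any Stable Diffusion checkpoint works)
pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# fetch an initial image to condition the generation on
url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
init_image = Image.open(BytesIO(requests.get(url).content)).convert("RGB").resize((768, 512))

prompt = "A fantasy landscape, trending on artstation"

# strength controls how strongly the initial image is altered
# (closer to 0.0 preserves it, closer to 1.0 ignores it)
image = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images[0]
image.save("fantasy_landscape.png")
```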