Commit
support diffuser==0.27.0
ResearcherXman committed Mar 31, 2024
1 parent 9d2762a commit e6b1f21
Showing 3 changed files with 113 additions and 29 deletions.
84 changes: 84 additions & 0 deletions infer_img2img.py
@@ -0,0 +1,84 @@
import cv2
import torch
import numpy as np
from PIL import Image

from diffusers.utils import load_image
from diffusers.models import ControlNetModel

from insightface.app import FaceAnalysis
from pipeline_stable_diffusion_xl_instantid_img2img import StableDiffusionXLInstantIDImg2ImgPipeline, draw_kps

def resize_img(input_image, max_side=1280, min_side=1024, size=None,
               pad_to_max_side=False, mode=Image.BILINEAR, base_pixel_number=64):

    w, h = input_image.size
    if size is not None:
        w_resize_new, h_resize_new = size
    else:
        ratio = min_side / min(h, w)
        w, h = round(ratio*w), round(ratio*h)
        ratio = max_side / max(h, w)
        input_image = input_image.resize([round(ratio*w), round(ratio*h)], mode)
        w_resize_new = (round(ratio * w) // base_pixel_number) * base_pixel_number
        h_resize_new = (round(ratio * h) // base_pixel_number) * base_pixel_number
    input_image = input_image.resize([w_resize_new, h_resize_new], mode)

    if pad_to_max_side:
        res = np.ones([max_side, max_side, 3], dtype=np.uint8) * 255
        offset_x = (max_side - w_resize_new) // 2
        offset_y = (max_side - h_resize_new) // 2
        res[offset_y:offset_y+h_resize_new, offset_x:offset_x+w_resize_new] = np.array(input_image)
        input_image = Image.fromarray(res)
    return input_image


if __name__ == "__main__":

    # Load face encoder
    app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
    app.prepare(ctx_id=0, det_size=(640, 640))

    # Path to InstantID models
    face_adapter = f'./checkpoints/ip-adapter.bin'
    controlnet_path = f'./checkpoints/ControlNetModel'

    # Load pipeline
    controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)

    base_model_path = 'stabilityai/stable-diffusion-xl-base-1.0'

    pipe = StableDiffusionXLInstantIDImg2ImgPipeline.from_pretrained(
        base_model_path,
        controlnet=controlnet,
        torch_dtype=torch.float16,
    )
    pipe.cuda()
    pipe.load_ip_adapter_instantid(face_adapter)

    # Infer setting
    prompt = "analog film photo of a man. faded film, desaturated, 35mm photo, grainy, vignette, vintage, Kodachrome, Lomography, stained, highly detailed, found footage, masterpiece, best quality"
    n_prompt = "(lowres, low quality, worst quality:1.2), (text:1.2), watermark, painting, drawing, illustration, glitch, deformed, mutated, cross-eyed, ugly, disfigured"

    face_image = load_image("./examples/yann-lecun_resize.jpg")
    face_image = resize_img(face_image)

    face_info = app.get(cv2.cvtColor(np.array(face_image), cv2.COLOR_RGB2BGR))
    face_info = sorted(face_info, key=lambda x: (x['bbox'][2]-x['bbox'][0])*(x['bbox'][3]-x['bbox'][1]))[-1]  # only use the maximum face
    face_emb = face_info['embedding']
    face_kps = draw_kps(face_image, face_info['kps'])

    image = pipe(
        prompt=prompt,
        negative_prompt=n_prompt,
        image=face_image,
        image_embeds=face_emb,
        control_image=face_kps,
        controlnet_conditioning_scale=0.8,
        ip_adapter_scale=0.8,
        num_inference_steps=30,
        guidance_scale=5,
        strength=0.85
    ).images[0]

    image.save('result.jpg')
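For a quick sense of what resize_img above does with its defaults, the sketch below feeds it a blank image (the 3000x2000 input size is an arbitrary assumption for illustration, not part of the commit): the short side is scaled toward min_side=1024, the long side is capped at max_side=1280, and both dimensions are snapped down to multiples of base_pixel_number=64.

from PIL import Image

from infer_img2img import resize_img  # importing needs the repo's dependencies installed

img = Image.new("RGB", (3000, 2000))
resized = resize_img(img)
print(resized.size)  # -> (1280, 832) with the default arguments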
28 changes: 14 additions & 14 deletions pipeline_stable_diffusion_xl_instantid.py
@@ -454,20 +454,20 @@ def __call__(

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
-            prompt,
-            prompt_2,
-            image,
-            callback_steps,
-            negative_prompt,
-            negative_prompt_2,
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-            controlnet_conditioning_scale,
-            control_guidance_start,
-            control_guidance_end,
-            callback_on_step_end_tensor_inputs,
+            prompt=prompt,
+            prompt_2=prompt_2,
+            image=image,
+            callback_steps=callback_steps,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            controlnet_conditioning_scale=controlnet_conditioning_scale,
+            control_guidance_start=control_guidance_start,
+            control_guidance_end=control_guidance_end,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
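Both this file and pipeline_stable_diffusion_xl_instantid_full.py below make the same change: check_inputs is now called with keyword arguments instead of positionally. A small self-contained sketch of why that matters when an upstream signature gains or reorders parameters between releases (the function below is a hypothetical stand-in, not the actual diffusers check_inputs API):

# Hypothetical stand-in, not the diffusers API: suppose the old signature was
# check_inputs(prompt, image, callback_steps) and a release inserts prompt_2.
def check_inputs(prompt, prompt_2=None, image=None, callback_steps=None):
    return {"prompt": prompt, "prompt_2": prompt_2,
            "image": image, "callback_steps": callback_steps}

# Positional call written against the old signature: values silently shift,
# "face.png" binds to prompt_2 and 1 binds to image.
shifted = check_inputs("a photo", "face.png", 1)

# Keyword call: each value still reaches the parameter it was meant for.
correct = check_inputs(prompt="a photo", image="face.png", callback_steps=1)

print(shifted["image"])   # 1 (wrong)
print(correct["image"])   # "face.png"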
30 changes: 15 additions & 15 deletions pipeline_stable_diffusion_xl_instantid_full.py
@@ -840,22 +840,22 @@ def __call__(

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
-            prompt,
-            prompt_2,
-            image,
-            callback_steps,
-            negative_prompt,
-            negative_prompt_2,
-            prompt_embeds,
-            negative_prompt_embeds,
-            pooled_prompt_embeds,
-            negative_pooled_prompt_embeds,
-            controlnet_conditioning_scale,
-            control_guidance_start,
-            control_guidance_end,
-            callback_on_step_end_tensor_inputs,
+            prompt=prompt,
+            prompt_2=prompt_2,
+            image=image,
+            callback_steps=callback_steps,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
+            controlnet_conditioning_scale=controlnet_conditioning_scale,
+            control_guidance_start=control_guidance_start,
+            control_guidance_end=control_guidance_end,
+            callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
        )

        self._guidance_scale = guidance_scale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
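Because the commit title pins the supported diffusers release, a short runtime check can catch a mismatched environment before running the updated scripts (a sketch, not part of this commit; the packaging module is assumed to be installed, as it normally is alongside pip):

import diffusers
from packaging import version

# Fail early if the installed diffusers release predates what this commit targets.
if version.parse(diffusers.__version__) < version.parse("0.27.0"):
    raise RuntimeError(
        f"diffusers {diffusers.__version__} found; these scripts were updated for 0.27.0"
    )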
