From 62ef61fce85e0baab015f411d480de96b3f9445c Mon Sep 17 00:00:00 2001
From: rlsn
Date: Tue, 2 Jan 2024 22:09:37 +0900
Subject: [PATCH] Initial commit

---
 .gitattributes |   2 +
 tripper.py     | 164 ++++++++++++++++++++++++++++++++++++++++++++++++
 utils.py       | 168 +++++++++++++++++++++++++++++++++++++++++++++++++
 3 files changed, 334 insertions(+)
 create mode 100644 .gitattributes
 create mode 100644 tripper.py
 create mode 100644 utils.py

diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000..dfe0770
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,2 @@
+# Auto detect text files and perform LF normalization
+* text=auto
diff --git a/tripper.py b/tripper.py
new file mode 100644
index 0000000..9a3fccf
--- /dev/null
+++ b/tripper.py
@@ -0,0 +1,164 @@
+import diffusers
+from diffusers import (StableDiffusionPipeline, StableDiffusionImg2ImgPipeline)
+import torch
+from utils import *
+import os,json
+
+class Tripper(object):
+    def __init__(self, model_file):
+        txt2img_pipe = StableDiffusionPipeline.from_ckpt(model_file, torch_dtype=torch.float16)
+        txt2img_pipe.safety_checker = lambda images,**kwargs: (images, [False] * len(images))
+        img2img_pipe = StableDiffusionImg2ImgPipeline(**txt2img_pipe.components)
+        img2img_pipe.safety_checker = lambda images,**kwargs: (images, [False] * len(images))
+
+        self.txt2img_pipe = txt2img_pipe.to('cuda')
+        self.img2img_pipe = img2img_pipe.to("cuda")
+        self.loras = dict()
+
+    def scheduler(self):
+        return self.txt2img_pipe.scheduler
+    def show_schedulers(self):
+        return self.txt2img_pipe.scheduler.compatibles
+    def set_scheduler(self, scheduler_cls):
+        self.txt2img_pipe.scheduler = scheduler_cls.from_config(self.txt2img_pipe.scheduler.config)
+        self.img2img_pipe.scheduler = scheduler_cls.from_config(self.img2img_pipe.scheduler.config)
+
+    def load_lora(self, lora_dict):
+        for lora in lora_dict:
+            if lora not in self.loras:
+                self.txt2img_pipe = load_lora_weights(self.txt2img_pipe, lora, lora_dict[lora], 'cuda', torch.float32, load=True)
+                self.loras[lora] = lora_dict[lora]
+                print(f"loaded {lora}")
+            else:
+                print(f"already loaded {lora}")
+    def unload_lora(self, lora_dict):
+        for lora in lora_dict:
+            if lora in self.loras:
+                self.txt2img_pipe = load_lora_weights(self.txt2img_pipe, lora, lora_dict[lora], 'cuda', torch.float32, load=False)
+                del self.loras[lora]
+                print(f"unloaded {lora}")
+            else:
+                print(f"have not loaded {lora}")
+
+    def txt2img(self, prompt, negative_prompt, lora_dict,
+                width=512, height=768, num_img=6, guidance_scale=7, num_inference_steps=25,
+                out_dir="preview"):
+        os.makedirs(out_dir, exist_ok = True)
+
+        self.load_lora(lora_dict)
+
+        prompt = clean_prompt(prompt)
+        prompt_embeds, negative_prompt_embeds = convert_prompt_embeds(self.txt2img_pipe, prompt, negative_prompt)
+        images = self.txt2img_pipe(prompt_embeds=prompt_embeds,
+                                   negative_prompt_embeds=negative_prompt_embeds,
+                                   guidance_scale=guidance_scale,
+                                   num_images_per_prompt=num_img,
+                                   num_inference_steps=num_inference_steps,
+                                   height=height, width=width,
+                                   ).images
+        for i,img in enumerate(images):
+            fn = f"{out_dir}/{timestr()}_{i}.jpg"
+            img.convert("RGB").save(fn)
+        self.unload_lora(lora_dict)
+
+        return images
+
+    def img2img(self, image, prompt, negative_prompt, lora_dict, strength=0.5,
+                num_img=6, guidance_scale=7, num_inference_steps=25,
+                out_dir="preview"):
+        os.makedirs(out_dir, exist_ok = True)
+
+        self.load_lora(lora_dict)
+
+        prompt = clean_prompt(prompt)
+        prompt_embeds, negative_prompt_embeds = convert_prompt_embeds(self.txt2img_pipe, prompt, negative_prompt)
+        images = self.img2img_pipe(prompt_embeds=prompt_embeds,
+                                   negative_prompt_embeds=negative_prompt_embeds,
+                                   image=image,
+                                   strength = strength,
+                                   num_images_per_prompt=num_img,
+                                   guidance_scale=guidance_scale,
+                                   num_inference_steps=num_inference_steps).images
+        for i,img in enumerate(images):
+            fn = f"{out_dir}/{timestr()}_{i}.jpg"
+            img.convert("RGB").save(fn)
+
+        self.unload_lora(lora_dict)
+        return images
+
+    def generate_video(self, init_image, prompt, negative_prompt,
+                       lora_dict, nsteps, strength_schedule,
+                       transform_fn,
+                       guidance_scale=7,
+                       num_inference_steps=40,
+                       out_dir="preview"):
+
+        os.makedirs(out_dir, exist_ok = True)
+
+        with open(f"{out_dir}/config.json","w") as fp:
+            config = {"prompt":prompt,
+                      "negative_prompt":negative_prompt,
+                      "loras":lora_dict,
+                      "guidance_scale":guidance_scale,
+                      "num_inference_steps":num_inference_steps,
+                      }
+            json.dump(config,fp,indent=4)
+
+        images = [init_image]
+        self.load_lora(lora_dict)
+
+
+        prompt = clean_prompt(prompt)
+        prompt_embeds, negative_prompt_embeds = convert_prompt_embeds(self.txt2img_pipe, prompt, negative_prompt)
+        for s in range(nsteps):
+            print(f"{s}/{nsteps}")
+            image = transform_fn(images[-1], s)
+            images += self.img2img_pipe(prompt_embeds=prompt_embeds,
+                                        negative_prompt_embeds=negative_prompt_embeds,
+                                        image=image,
+                                        strength = strength_schedule[s],
+                                        guidance_scale=guidance_scale,
+                                        num_inference_steps=num_inference_steps).images
+
+            fn = out_dir+"/%06d.jpg"%s
+            images[-1].convert("RGB").save(fn)
+
+        self.unload_lora(lora_dict)
+        return images
+
+
+    # def batch_generate(self, general_prompt, character_dict, addition_list, lora_dict,
+    #                    negative_prompt, img_per_comb=6, save_dir=".", guidance_scale=7, num_inference_steps=25):
+    #     for character in character_dict:
+    #         try:
+    #             pipeline = load_lora_weights(pipeline, character, 1., 'cuda', torch.float32, load=True)
+    #             print(f"loaded {character}")
+    #         except:
+    #             continue
+    #         for lora in lora_dict:
+    #             try:
+    #                 pipeline = load_lora_weights(pipeline, lora, 1., 'cuda', torch.float32, load=True)
+    #                 print(f"loaded {lora}")
+    #             except:
+    #                 continue
+    #             for addition in addition_list:
+    #                 width = lora_dict[lora][1]
+    #                 height = lora_dict[lora][2]
+    #                 prompt = general_prompt + lora_dict[lora][0] + addition + character_dict[character]
+    #                 prompt = clean_prompt(prompt)
+    #                 prompt_embeds, negative_prompt_embeds = convert_prompt_embeds(pipeline, prompt, negative_prompt)
+    #                 images = txt2img_pipe(prompt_embeds=prompt_embeds,
+    #                                       negative_prompt_embeds=negative_prompt_embeds,
+    #                                       guidance_scale=guidance_scale,
+    #                                       num_images_per_prompt=img_per_comb,
+    #                                       num_inference_steps=num_inference_steps,
+    #                                       height=height, width=width,
+    #                                       ).images
+    #                 for img in images:
+    #                     fn = f"{save_dir}/{character.split('.')[0]}_{lora.split('.')[0]}_{int(np.random.rand()*1e6)}.jpg"
+    #                     img.convert("RGB").save(fn)
+    #                     print(f"saved {fn}")
+    #             pipeline = load_lora_weights(pipeline, lora, 1., 'cuda', torch.float32, load=False)
+    #             print(f"unloaded {lora}")
+    #         pipeline = load_lora_weights(pipeline, character, 1., 'cuda', torch.float32, load=False)
+    #         print(f"unloaded {character}")
\ No newline at end of file
diff --git a/utils.py b/utils.py
new file mode 100644
index 0000000..456a89e
--- /dev/null
+++ b/utils.py
@@ -0,0 +1,168 @@
+from collections import defaultdict
+from einops import einsum
+import torch
+import PIL
+from PIL import Image
+import numpy as np
+import time
+from safetensors.torch import load_file
+from scipy.stats import norm
+
+def timestr():
+    return time.strftime('%Y%m%d%H%M%S', time.localtime())
+
+def image_grid(imgs, rows, cols):
+    assert len(imgs) == rows*cols
+
+    w, h = imgs[0].size
+    grid = Image.new('RGB', size=(cols*w, rows*h))
+    grid_w, grid_h = grid.size
+
+    for i, img in enumerate(imgs):
+        grid.paste(img, box=(i%cols*w, i//cols*h))
+    return grid
+
+def export_as_gif(filename, images, frames_per_second=10, rubber_band=False):
+    if rubber_band:
+        images += images[2:-1][::-1]
+    images[0].save(
+        filename,
+        save_all=True,
+        append_images=images[1:],
+        duration=1000 // frames_per_second,
+        loop=0,
+    )
+
+def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype, load=True):
+    LORA_PREFIX_UNET = "lora_unet"
+    LORA_PREFIX_TEXT_ENCODER = "lora_te"
+    # load LoRA weight from .safetensors
+    state_dict = load_file(checkpoint_path, device=device)
+
+    updates = defaultdict(dict)
+    for key, value in state_dict.items():
+        # it is suggested to print out the key; it will usually look something like
+        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
+
+        layer, elem = key.split('.', 1)
+        updates[layer][elem] = value
+
+    # directly update weight in diffusers model
+    for layer, elems in updates.items():
+
+        if "text" in layer:
+            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+            curr_layer = pipeline.text_encoder
+        else:
+            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
+            curr_layer = pipeline.unet
+
+        # find the target layer
+        temp_name = layer_infos.pop(0)
+        while len(layer_infos) > -1:
+            try:
+                curr_layer = curr_layer.__getattr__(temp_name)
+                if len(layer_infos) > 0:
+                    temp_name = layer_infos.pop(0)
+                elif len(layer_infos) == 0:
+                    break
+            except Exception:
+                if len(temp_name) > 0:
+                    temp_name += "_" + layer_infos.pop(0)
+                else:
+                    temp_name = layer_infos.pop(0)
+
+        # get elements for this layer
+        weight_up = elems['lora_up.weight'].to(dtype)
+        weight_down = elems['lora_down.weight'].to(dtype)
+        alpha = elems['alpha']
+        if alpha:
+            alpha = alpha.item() / weight_up.shape[1]
+        else:
+            alpha = 1.0
+
+        # update weight
+        if len(weight_up.shape) == 4:
+            x = einsum(weight_up.expand(-1,-1,weight_down.size(2), weight_down.size(3)),weight_down,"c1 k h w, k c2 h w -> c1 c2 h w")
+            if load:
+                curr_layer.weight.data += multiplier * alpha * x
+                # curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+            else:
+                curr_layer.weight.data -= multiplier * alpha * x
+                # curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+        else:
+            if load:
+                curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
+            else:
+                curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)
+
+    return pipeline
+
+
+def convert_prompt_embeds(pipe, prompt,negative_prompt):
+    max_length = pipe.tokenizer.model_max_length
+
+    input_ids = pipe.tokenizer(prompt, return_tensors="pt").input_ids
+
+    negative_ids = pipe.tokenizer(negative_prompt, return_tensors="pt").input_ids
+    if input_ids.shape[-1]>negative_ids.shape[-1]:
+        negative_ids = pipe.tokenizer(negative_prompt, truncation=False, padding="max_length", max_length=input_ids.shape[-1], return_tensors="pt").input_ids
+    elif input_ids.shape[-1]