Commit 62ef61f
Initial commit
rlsn committed Jan 2, 2024
0 parents
Showing 3 changed files with 334 additions and 0 deletions.
2 changes: 2 additions & 0 deletions .gitattributes
@@ -0,0 +1,2 @@
# Auto detect text files and perform LF normalization
* text=auto
164 changes: 164 additions & 0 deletions tripper.py
@@ -0,0 +1,164 @@
import diffusers
from diffusers import (StableDiffusionPipeline, StableDiffusionImg2ImgPipeline)
import torch
from utils import *
import os, json

class Tripper(object):
    def __init__(self, model_file):
        # Build a txt2img pipeline from a local checkpoint, share its components
        # with an img2img pipeline, and disable the safety checker on both.
        txt2img_pipe = StableDiffusionPipeline.from_ckpt(model_file, torch_dtype=torch.float16)
        txt2img_pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))
        img2img_pipe = StableDiffusionImg2ImgPipeline(**txt2img_pipe.components)
        img2img_pipe.safety_checker = lambda images, **kwargs: (images, [False] * len(images))

        self.txt2img_pipe = txt2img_pipe.to("cuda")
        self.img2img_pipe = img2img_pipe.to("cuda")
        self.loras = dict()

    def scheduler(self):
        return self.txt2img_pipe.scheduler

    def show_schedulers(self):
        return self.txt2img_pipe.scheduler.compatibles

    def set_scheduler(self, scheduler_cls):
        self.txt2img_pipe.scheduler = scheduler_cls.from_config(self.txt2img_pipe.scheduler.config)
        self.img2img_pipe.scheduler = scheduler_cls.from_config(self.img2img_pipe.scheduler.config)

    def load_lora(self, lora_dict):
        for lora in lora_dict:
            if lora not in self.loras:
                self.txt2img_pipe = load_lora_weights(self.txt2img_pipe, lora, lora_dict[lora], 'cuda', torch.float32, load=True)
                self.loras[lora] = lora_dict[lora]
                print(f"loaded {lora}")
            else:
                print(f"already loaded {lora}")

    def unload_lora(self, lora_dict):
        for lora in lora_dict:
            if lora in self.loras:
                self.txt2img_pipe = load_lora_weights(self.txt2img_pipe, lora, lora_dict[lora], 'cuda', torch.float32, load=False)
                del self.loras[lora]
                print(f"unloaded {lora}")
            else:
                print(f"have not loaded {lora}")

    def txt2img(self, prompt, negative_prompt, lora_dict,
                width=512, height=768, num_img=6, guidance_scale=7, num_inference_steps=25,
                out_dir="preview"):
        os.makedirs(out_dir, exist_ok=True)

        self.load_lora(lora_dict)

        prompt = clean_prompt(prompt)
        prompt_embeds, negative_prompt_embeds = convert_prompt_embeds(self.txt2img_pipe, prompt, negative_prompt)
        images = self.txt2img_pipe(prompt_embeds=prompt_embeds,
                                   negative_prompt_embeds=negative_prompt_embeds,
                                   guidance_scale=guidance_scale,
                                   num_images_per_prompt=num_img,
                                   num_inference_steps=num_inference_steps,
                                   height=height, width=width,
                                   ).images
        for i, img in enumerate(images):
            fn = f"{out_dir}/{timestr()}_{i}.jpg"
            img.convert("RGB").save(fn)
        self.unload_lora(lora_dict)

        return images

    def img2img(self, image, prompt, negative_prompt, lora_dict, strength=0.5,
                num_img=6, guidance_scale=7, num_inference_steps=25,
                out_dir="preview"):
        os.makedirs(out_dir, exist_ok=True)

        self.load_lora(lora_dict)

        prompt = clean_prompt(prompt)
        prompt_embeds, negative_prompt_embeds = convert_prompt_embeds(self.txt2img_pipe, prompt, negative_prompt)
        images = self.img2img_pipe(prompt_embeds=prompt_embeds,
                                   negative_prompt_embeds=negative_prompt_embeds,
                                   image=image,
                                   strength=strength,
                                   num_images_per_prompt=num_img,
                                   guidance_scale=guidance_scale,
                                   num_inference_steps=num_inference_steps).images
        for i, img in enumerate(images):
            fn = f"{out_dir}/{timestr()}_{i}.jpg"
            img.convert("RGB").save(fn)

        self.unload_lora(lora_dict)
        return images

    def generate_video(self, init_image, prompt, negative_prompt,
                       lora_dict, nsteps, strength_schedule,
                       transform_fn,
                       guidance_scale=7,
                       num_inference_steps=40,
                       out_dir="preview"):

        os.makedirs(out_dir, exist_ok=True)

        with open(f"{out_dir}/config.json", "w") as fp:
            config = {"prompt": prompt,
                      "negative_prompt": negative_prompt,
                      "loras": lora_dict,
                      "guidance_scale": guidance_scale,
                      "num_inference_steps": num_inference_steps,
                      }
            json.dump(config, fp, indent=4)

        images = [init_image]
        self.load_lora(lora_dict)

        prompt = clean_prompt(prompt)
        prompt_embeds, negative_prompt_embeds = convert_prompt_embeds(self.txt2img_pipe, prompt, negative_prompt)
        for s in range(nsteps):
            print(f"{s}/{nsteps}")
            # Apply the per-frame transform (e.g. a zoom) to the latest frame,
            # then re-diffuse it with the scheduled denoising strength.
            image = transform_fn(images[-1], s)
            images += self.img2img_pipe(prompt_embeds=prompt_embeds,
                                        negative_prompt_embeds=negative_prompt_embeds,
                                        image=image,
                                        strength=strength_schedule[s],
                                        guidance_scale=guidance_scale,
                                        num_inference_steps=num_inference_steps).images

            fn = out_dir + "/%06d.jpg" % s
            images[-1].convert("RGB").save(fn)

        self.unload_lora(lora_dict)
        return images


    # def batch_generate(self, general_prompt, character_dict, addition_list, lora_dict,
    #                    negative_prompt, img_per_comb=6, save_dir=".", guidance_scale=7, num_inference_steps=25):
    #     for character in character_dict:
    #         try:
    #             pipeline = load_lora_weights(pipeline, character, 1., 'cuda', torch.float32, load=True)
    #             print(f"loaded {character}")
    #         except:
    #             continue
    #         for lora in lora_dict:
    #             try:
    #                 pipeline = load_lora_weights(pipeline, lora, 1., 'cuda', torch.float32, load=True)
    #                 print(f"loaded {lora}")
    #             except:
    #                 continue
    #             for addition in addition_list:
    #                 width = lora_dict[lora][1]
    #                 height = lora_dict[lora][2]
    #                 prompt = general_prompt + lora_dict[lora][0] + addition + character_dict[character]
    #                 prompt = clean_prompt(prompt)
    #                 prompt_embeds, negative_prompt_embeds = convert_prompt_embeds(pipeline, prompt, negative_prompt)
    #                 images = txt2img_pipe(prompt_embeds=prompt_embeds,
    #                                       negative_prompt_embeds=negative_prompt_embeds,
    #                                       guidance_scale=guidance_scale,
    #                                       num_images_per_prompt=img_per_comb,
    #                                       num_inference_steps=num_inference_steps,
    #                                       height=height, width=width,
    #                                       ).images
    #                 for img in images:
    #                     fn = f"{save_dir}/{character.split('.')[0]}_{lora.split('.')[0]}_{int(np.random.rand()*1e6)}.jpg"
    #                     img.convert("RGB").save(fn)
    #                     print(f"saved {fn}")
    #             pipeline = load_lora_weights(pipeline, lora, 1., 'cuda', torch.float32, load=False)
    #             print(f"unloaded {lora}")
    #         pipeline = load_lora_weights(pipeline, character, 1., 'cuda', torch.float32, load=False)
    #         print(f"unloaded {character}")
168 changes: 168 additions & 0 deletions utils.py
@@ -0,0 +1,168 @@
from collections import defaultdict
from einops import einsum
import torch
import PIL
from PIL import Image
import numpy as np
import time
from safetensors.torch import load_file
from scipy.stats import norm

def timestr():
    return time.strftime('%Y%m%d%H%M%S', time.localtime())

def image_grid(imgs, rows, cols):
    assert len(imgs) == rows * cols

    w, h = imgs[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    grid_w, grid_h = grid.size

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

def export_as_gif(filename, images, frames_per_second=10, rubber_band=False):
    if rubber_band:
        # append the frames in reverse so the GIF bounces back and forth
        images += images[2:-1][::-1]
    images[0].save(
        filename,
        save_all=True,
        append_images=images[1:],
        duration=1000 // frames_per_second,
        loop=0,
    )

def load_lora_weights(pipeline, checkpoint_path, multiplier, device, dtype, load=True):
    LORA_PREFIX_UNET = "lora_unet"
    LORA_PREFIX_TEXT_ENCODER = "lora_te"
    # load LoRA weight from .safetensors
    state_dict = load_file(checkpoint_path, device=device)

    updates = defaultdict(dict)
    for key, value in state_dict.items():
        # it is suggested to print out the key, it usually will be something like below
        # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
        layer, elem = key.split('.', 1)
        updates[layer][elem] = value

    # directly update weight in diffusers model
    for layer, elems in updates.items():

        if "text" in layer:
            layer_infos = layer.split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
            curr_layer = pipeline.text_encoder
        else:
            layer_infos = layer.split(LORA_PREFIX_UNET + "_")[-1].split("_")
            curr_layer = pipeline.unet

        # find the target layer
        temp_name = layer_infos.pop(0)
        while len(layer_infos) > -1:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                elif len(layer_infos) == 0:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += "_" + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        # get elements for this layer
        weight_up = elems['lora_up.weight'].to(dtype)
        weight_down = elems['lora_down.weight'].to(dtype)
        alpha = elems['alpha']
        if alpha:
            alpha = alpha.item() / weight_up.shape[1]
        else:
            alpha = 1.0

        # update weight: add the LoRA delta when loading, subtract it when unloading
        if len(weight_up.shape) == 4:
            x = einsum(weight_up.expand(-1, -1, weight_down.size(2), weight_down.size(3)), weight_down,
                       "c1 k h w, k c2 h w -> c1 c2 h w")
            if load:
                curr_layer.weight.data += multiplier * alpha * x
                # curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
            else:
                curr_layer.weight.data -= multiplier * alpha * x
                # curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up.squeeze(3).squeeze(2), weight_down.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
        else:
            if load:
                curr_layer.weight.data += multiplier * alpha * torch.mm(weight_up, weight_down)
            else:
                curr_layer.weight.data -= multiplier * alpha * torch.mm(weight_up, weight_down)

    return pipeline


def convert_prompt_embeds(pipe, prompt, negative_prompt):
    max_length = pipe.tokenizer.model_max_length

    input_ids = pipe.tokenizer(prompt, return_tensors="pt").input_ids
    negative_ids = pipe.tokenizer(negative_prompt, return_tensors="pt").input_ids

    # pad the shorter prompt so both id tensors have the same length
    if input_ids.shape[-1] > negative_ids.shape[-1]:
        negative_ids = pipe.tokenizer(negative_prompt, truncation=False, padding="max_length",
                                      max_length=input_ids.shape[-1], return_tensors="pt").input_ids
    elif input_ids.shape[-1] < negative_ids.shape[-1]:
        input_ids = pipe.tokenizer(prompt, truncation=False, padding="max_length",
                                   max_length=negative_ids.shape[-1], return_tensors="pt").input_ids

    input_ids = input_ids.to("cuda")
    negative_ids = negative_ids.to("cuda")

    # encode in chunks of the tokenizer's max length and concatenate,
    # so prompts longer than the usual 77-token limit are not truncated
    concat_embeds = []
    neg_embeds = []
    for i in range(0, input_ids.shape[-1], max_length):
        concat_embeds.append(pipe.text_encoder(input_ids[:, i: i + max_length])[0])
        neg_embeds.append(pipe.text_encoder(negative_ids[:, i: i + max_length])[0])

    prompt_embeds = torch.cat(concat_embeds, dim=1)
    negative_prompt_embeds = torch.cat(neg_embeds, dim=1)
    return prompt_embeds, negative_prompt_embeds


def clean_prompt(s):
    # drop <lora:...> tags and duplicate comma-separated tokens
    tokens = s.split(',')
    clean = []
    for t in tokens:
        t = t.strip()
        if t.startswith("<lora"):
            continue
        if t not in clean:
            clean.append(t)
    s = ", ".join(clean)
    return s

def cos_schedule(low, high, phase, cycles, steps):
    amp = (high - low) / 2
    offset = (high + low) / 2
    return offset + np.cos(np.linspace(phase, phase + 2 * np.pi * cycles, steps)) * amp

def const_schedule(v, steps):
    return np.array([v] * steps)

def zoom(im, ratio):
    if ratio < 1:
        # crop a centered window of the given ratio and scale it back up to the original size
        s = im.size
        w, h = im.size[0] * ratio, im.size[1] * ratio
        m = (s[0] - w) / 2, (s[1] - h) / 2
        nim = im.crop((m[0], m[1], s[0] - m[0], s[1] - m[1]))
        return nim.resize(s)
    else:
        # todo
        return im


def impulse_schedule(floor, ceiling, impulse, width, steps):
    # Gaussian bumps from floor up to ceiling, centered at the given frame indices
    x = np.arange(steps)
    Y = []
    for imp in impulse:
        y = norm.pdf(x, imp, width)
        y *= (ceiling - floor) / y.max()
        Y.append(y)
    Y = np.array(Y).sum(0)
    print(Y.shape)
    return Y + floor
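
For reference, a small sketch (also not part of this commit) of what the schedule helpers above return; the step count and bounds are illustrative assumptions.

import numpy as np
from utils import cos_schedule, const_schedule, impulse_schedule

steps = 12

# constant denoising strength for every frame
print(const_schedule(0.5, steps))             # twelve 0.5s

# starts at 0.6 and oscillates between 0.3 and 0.6 over one full cosine cycle
print(cos_schedule(0.3, 0.6, 0, 1, steps))

# mostly 0.3, with Gaussian bumps up toward 0.7 around frames 3 and 9
print(impulse_schedule(0.3, 0.7, impulse=[3, 9], width=1.0, steps=steps))
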
