diff --git a/sd-webui-hunyuan-dit/LICENSE b/sd-webui-hunyuan-dit/LICENSE
new file mode 100755
index 0000000..3ff4b08
--- /dev/null
+++ b/sd-webui-hunyuan-dit/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2024 sethgggg
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
\ No newline at end of file
diff --git a/sd-webui-hunyuan-dit/README.md b/sd-webui-hunyuan-dit/README.md
new file mode 100755
index 0000000..f23a554
--- /dev/null
+++ b/sd-webui-hunyuan-dit/README.md
@@ -0,0 +1,68 @@
+# Hunyuan extension for sd-webui
+
+This extension lets you use the [Hunyuan DiT Model](https://github.com/Tencent/HunyuanDiT) in [Stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui):
+
+### Features
+
+- Core
+  - [x] [Txt2Img]
+  - [x] [Img2Img]
+  - [ ] [LORA]
+  - [ ] [ControlNet]
+  - [ ] [HiresUpscaler]
+- Advanced
+  - [ ] [MultiDiffusion]
+  - [ ] [Adetailer]
+
+### Installation
+
+1. You can install this extension from the webui extension installer by entering the repository URL ```https://github.com/sethgggg/sd-webui-hunyuan-dit.git```.
+
+![install](examples/20240709-000053.jpg)
+
+2. Download the HunyuanDiT model from [Huggingface](https://huggingface.co/Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers) to local storage; the default storage location is ```models/hunyuan``` inside the webui folder. You can change this location in the webui settings.
+
+![folder](examples/20240708-235015.jpg)
+
+![settings](examples/20240708-235001.jpg)
+
+3. Place the transformer model in ```models/Stable-Diffusion```, the main storage location for checkpoints. If you have fine-tuned a new model, place its transformer weights in the same folder, and you can then select the model here.
+
+4. Find the HunyuanDiT card and enable it. If you want to use Stable Diffusion models instead, remember to disable the HunyuanDiT model.
+
+![enable](examples/20240708-235013.jpg)
+
+5. This project uses diffusers as the inference backend, so the following samplers are supported:
+
+| Sampler Name            | Sampler Instance in diffusers                                                          |
+|-------------------------|----------------------------------------------------------------------------------------|
+| Euler a                 | EulerAncestralDiscreteScheduler()                                                      |
+| Euler                   | EulerDiscreteScheduler()                                                               |
+| LMS                     | LMSDiscreteScheduler()                                                                 |
+| Heun                    | HeunDiscreteScheduler()                                                                |
+| DPM2                    | KDPM2DiscreteScheduler()                                                               |
+| DPM2 a                  | KDPM2AncestralDiscreteScheduler()                                                      |
+| DPM++ SDE               | DPMSolverSinglestepScheduler()                                                         |
+| DPM++ 2M                | DPMSolverMultistepScheduler()                                                          |
+| DPM++ 2S a              | DPMSolverSinglestepScheduler()                                                         |
+| LMS Karras              | LMSDiscreteScheduler(use_karras_sigmas=True)                                           |
+| DPM2 Karras             | KDPM2DiscreteScheduler(use_karras_sigmas=True)                                         |
+| DPM2 a Karras           | KDPM2AncestralDiscreteScheduler(use_karras_sigmas=True)                                |
+| DPM++ SDE Karras        | DPMSolverSinglestepScheduler(use_karras_sigmas=True)                                   |
+| DPM++ 2M Karras         | DPMSolverMultistepScheduler(use_karras_sigmas=True)                                    |
+| DPM++ 2S a Karras       | DPMSolverSinglestepScheduler(use_karras_sigmas=True)                                   |
+| DDIM                    | DDIMScheduler()                                                                        |
+| UniPC                   | UniPCMultistepScheduler()                                                              |
+| DPM++ 2M SDE Karras     | DPMSolverMultistepScheduler(use_karras_sigmas=True, algorithm_type="sde-dpmsolver++")  |
+| DPM++ 2M SDE            | DPMSolverMultistepScheduler(algorithm_type="sde-dpmsolver++")                          |
+| LCM                     | LCMScheduler()                                                                         |
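+
+For reference, the extension resolves these names to diffusers scheduler instances through the ```dit_sampler_dict``` mapping in ```hunyuan_utils/utils.py```. Below is a minimal sketch of that lookup (only a few entries shown; the model path in the final comment is a placeholder for your local download):
+
+```python
+# Sketch of how a webui sampler name is mapped to a diffusers scheduler,
+# mirroring the dit_sampler_dict lookup in hunyuan_utils/utils.py.
+from diffusers.schedulers import (
+    DDPMScheduler,
+    EulerAncestralDiscreteScheduler,
+    EulerDiscreteScheduler,
+    DPMSolverMultistepScheduler,
+)
+
+# excerpt of the full table above; unknown names fall back to DDPM
+sampler_map = {
+    "Euler a": EulerAncestralDiscreteScheduler(),
+    "Euler": EulerDiscreteScheduler(),
+    "DPM++ 2M Karras": DPMSolverMultistepScheduler(use_karras_sigmas=True),
+}
+
+scheduler = sampler_map.get("DPM++ 2M Karras", DDPMScheduler())
+# the extension then reloads the scheduler config shipped with the model, e.g.
+# scheduler.from_pretrained("<local HunyuanDiT-v1.2-Diffusers path>", subfolder="scheduler")
+```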
+
+### Examples
+
+⚪ Txt2img: generate images from a text prompt; webui-style prompt syntax is supported.
+
+![txt2img](examples/20240708-235005.jpg)
+
+⚪ Img2img: given an image, you can use the Hunyuan DiT model to generate new images from it.
+
+![img2img](examples/20240708-234944.jpg)
\ No newline at end of file
diff --git a/sd-webui-hunyuan-dit/examples/20240708-234944.jpg b/sd-webui-hunyuan-dit/examples/20240708-234944.jpg
new file mode 100644
index 0000000..7fe7b05
Binary files /dev/null and b/sd-webui-hunyuan-dit/examples/20240708-234944.jpg differ
diff --git a/sd-webui-hunyuan-dit/examples/20240708-235001.jpg b/sd-webui-hunyuan-dit/examples/20240708-235001.jpg
new file mode 100644
index 0000000..8ee590b
Binary files /dev/null and b/sd-webui-hunyuan-dit/examples/20240708-235001.jpg differ
diff --git a/sd-webui-hunyuan-dit/examples/20240708-235005.jpg b/sd-webui-hunyuan-dit/examples/20240708-235005.jpg
new file mode 100644
index 0000000..25bd3d2
Binary files /dev/null and b/sd-webui-hunyuan-dit/examples/20240708-235005.jpg differ
diff --git a/sd-webui-hunyuan-dit/examples/20240708-235013.jpg b/sd-webui-hunyuan-dit/examples/20240708-235013.jpg
new file mode 100644
index 0000000..ef12254
Binary files /dev/null and b/sd-webui-hunyuan-dit/examples/20240708-235013.jpg differ
diff --git a/sd-webui-hunyuan-dit/examples/20240708-235015.jpg b/sd-webui-hunyuan-dit/examples/20240708-235015.jpg
new file mode 100644
index 0000000..4a7d31f
Binary files /dev/null and b/sd-webui-hunyuan-dit/examples/20240708-235015.jpg differ
diff --git a/sd-webui-hunyuan-dit/examples/20240709-000053.jpg b/sd-webui-hunyuan-dit/examples/20240709-000053.jpg
new file mode 100644
index 0000000..05c8aaf
Binary files /dev/null and b/sd-webui-hunyuan-dit/examples/20240709-000053.jpg differ
diff --git a/sd-webui-hunyuan-dit/hunyuan_utils/diffusers_learned_conditioning.py b/sd-webui-hunyuan-dit/hunyuan_utils/diffusers_learned_conditioning.py
new file mode 100755
index 0000000..1cf8756
--- /dev/null
+++ b/sd-webui-hunyuan-dit/hunyuan_utils/diffusers_learned_conditioning.py
@@ -0,0 +1,6 @@
+from modules import prompt_parser, shared
+
+def get_learned_conditioning_hunyuan(batch: 
prompt_parser.SdConditioning | list[str]): + clip_l_conds, clip_l_attention = shared.clip_l_model(batch) + t5_conds, t5_attention = shared.mt5_model(batch) + return {"crossattn":clip_l_conds, "mask":clip_l_attention, "crossattn_2":t5_conds, "mask_2":t5_attention} \ No newline at end of file diff --git a/sd-webui-hunyuan-dit/hunyuan_utils/sd_hijack_clip_diffusers.py b/sd-webui-hunyuan-dit/hunyuan_utils/sd_hijack_clip_diffusers.py new file mode 100755 index 0000000..f8b3de9 --- /dev/null +++ b/sd-webui-hunyuan-dit/hunyuan_utils/sd_hijack_clip_diffusers.py @@ -0,0 +1,727 @@ +import math +from collections import namedtuple + +import torch + +from modules import prompt_parser, devices, sd_hijack, sd_emphasis +from modules.shared import opts + + +class PromptChunk: + """ + This object contains token ids, weight (multipliers:1.4) and textual inversion embedding info for a chunk of prompt. + If a prompt is short, it is represented by one PromptChunk, otherwise, multiple are necessary. + Each PromptChunk contains an exact amount of tokens - 77, which includes one for start and end token, + so just 75 tokens from prompt. + """ + + def __init__(self): + self.tokens = [] + self.multipliers = [] + self.fixes = [] + + +PromptChunkFix = namedtuple('PromptChunkFix', ['offset', 'embedding']) +"""An object of this type is a marker showing that textual inversion embedding's vectors have to placed at offset in the prompt +chunk. Thos objects are found in PromptChunk.fixes and, are placed into FrozenCLIPEmbedderWithCustomWordsBase.hijack.fixes, and finally +are applied by sd_hijack.EmbeddingsWithFixes's forward function.""" + + +class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module): + """A pytorch module that is a wrapper for FrozenCLIPEmbedder module. it enhances FrozenCLIPEmbedder, making it possible to + have unlimited prompt length and assign weights to tokens in prompt. + """ + + def __init__(self, wrapped, hijack): + super().__init__() + + self.wrapped = wrapped + """Original FrozenCLIPEmbedder module; can also be FrozenOpenCLIPEmbedder or xlmr.BertSeriesModelWithTransformation, + depending on model.""" + + self.hijack: sd_hijack.StableDiffusionModelHijack = hijack + self.chunk_length = 75 + + self.is_trainable = getattr(wrapped, 'is_trainable', False) + self.input_key = getattr(wrapped, 'input_key', 'txt') + self.legacy_ucg_val = None + + def empty_chunk(self): + """creates an empty PromptChunk and returns it""" + + chunk = PromptChunk() + chunk.tokens = [self.id_start] + [self.id_end] * (self.chunk_length + 1) + chunk.multipliers = [1.0] * (self.chunk_length + 2) + return chunk + + def get_target_prompt_token_count(self, token_count): + """returns the maximum number of tokens a prompt of a known length can have before it requires one more PromptChunk to be represented""" + + return math.ceil(max(token_count, 1) / self.chunk_length) * self.chunk_length + + def tokenize(self, texts): + """Converts a batch of texts into a batch of token ids""" + + raise NotImplementedError + + def encode_with_transformers(self, tokens): + """ + converts a batch of token ids (in python lists) into a single tensor with numeric respresentation of those tokens; + All python lists with tokens are assumed to have same length, usually 77. + if input is a list with B elements and each element has T tokens, expected output shape is (B, T, C), where C depends on + model - can be 768 and 1024. + Among other things, this call will read self.hijack.fixes, apply it to its inputs, and clear it (setting it to None). 
+ """ + + raise NotImplementedError + + def encode_embedding_init_text(self, init_text, nvpt): + """Converts text into a tensor with this text's tokens' embeddings. Note that those are embeddings before they are passed through + transformers. nvpt is used as a maximum length in tokens. If text produces less teokens than nvpt, only this many is returned.""" + + raise NotImplementedError + + def tokenize_line(self, line): + """ + this transforms a single prompt into a list of PromptChunk objects - as many as needed to + represent the prompt. + Returns the list and the total number of tokens in the prompt. + """ + + if opts.emphasis != "None": + parsed = prompt_parser.parse_prompt_attention(line) + else: + parsed = [[line, 1.0]] + + tokenized = self.tokenize([text for text, _ in parsed]) + + chunks = [] + chunk = PromptChunk() + token_count = 0 + last_comma = -1 + + def next_chunk(is_last=False): + """puts current chunk into the list of results and produces the next one - empty; + if is_last is true, tokens tokens at the end won't add to token_count""" + nonlocal token_count + nonlocal last_comma + nonlocal chunk + + if is_last: + token_count += len(chunk.tokens) + else: + token_count += self.chunk_length + + to_add = self.chunk_length - len(chunk.tokens) + if to_add > 0: + chunk.tokens += [self.id_end] * to_add + chunk.multipliers += [1.0] * to_add + + chunk.tokens = [self.id_start] + chunk.tokens + [self.id_end] + chunk.multipliers = [1.0] + chunk.multipliers + [1.0] + + last_comma = -1 + chunks.append(chunk) + chunk = PromptChunk() + + for tokens, (text, weight) in zip(tokenized, parsed): + if text == 'BREAK' and weight == -1: + next_chunk() + continue + + position = 0 + while position < len(tokens): + token = tokens[position] + + if token == self.comma_token: + last_comma = len(chunk.tokens) + + # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack + # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next. + elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: + break_location = last_comma + 1 + + reloc_tokens = chunk.tokens[break_location:] + reloc_mults = chunk.multipliers[break_location:] + + chunk.tokens = chunk.tokens[:break_location] + chunk.multipliers = chunk.multipliers[:break_location] + + next_chunk() + chunk.tokens = reloc_tokens + chunk.multipliers = reloc_mults + + if len(chunk.tokens) == self.chunk_length: + next_chunk() + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position) + if embedding is None: + chunk.tokens.append(token) + chunk.multipliers.append(weight) + position += 1 + continue + + emb_len = int(embedding.vectors) + if len(chunk.tokens) + emb_len > self.chunk_length: + next_chunk() + + chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding)) + + chunk.tokens += [0] * emb_len + chunk.multipliers += [weight] * emb_len + position += embedding_length_in_tokens + + if chunk.tokens or not chunks: + next_chunk(is_last=True) + + return chunks, token_count + + def process_texts(self, texts): + """ + Accepts a list of texts and calls tokenize_line() on each, with cache. Returns the list of results and maximum + length, in tokens, of all texts. 
+ """ + + token_count = 0 + + cache = {} + batch_chunks = [] + for line in texts: + if line in cache: + chunks = cache[line] + else: + chunks, current_token_count = self.tokenize_line(line) + token_count = max(current_token_count, token_count) + + cache[line] = chunks + + batch_chunks.append(chunks) + + return batch_chunks, token_count + + def forward(self, texts): + """ + Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. + Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will + be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280. + An example shape returned by this function can be: (2, 77, 768). + For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values. + Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet + is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" + """ + + if opts.use_old_emphasis_implementation: + import modules.sd_hijack_clip_old + return modules.sd_hijack_clip_old.forward_old(self, texts) + + batch_chunks, token_count = self.process_texts(texts) + + used_embeddings = {} + chunk_count = max([len(x) for x in batch_chunks]) + + zs = [] + tk = [] + for i in range(chunk_count): + batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks] + + tokens = [x.tokens for x in batch_chunk] + attn_mask = [] + for token in tokens: + temp_mask = [] + for token_id in token: + if token_id != self.id_end: + temp_mask.append(1) + else: + temp_mask.append(0) + attn_mask.append(temp_mask) + multipliers = [x.multipliers for x in batch_chunk] + self.hijack.fixes = [x.fixes for x in batch_chunk] + + for fixes in self.hijack.fixes: + for _position, embedding in fixes: + used_embeddings[embedding.name] = embedding + + z = self.process_tokens(tokens, multipliers) + zs.append(z) + tk.append(torch.tensor(attn_mask)) + + if opts.textual_inversion_add_hashes_to_infotext and used_embeddings: + hashes = [] + for name, embedding in used_embeddings.items(): + shorthash = embedding.shorthash + if not shorthash: + continue + + name = name.replace(":", "").replace(",", "") + hashes.append(f"{name}: {shorthash}") + + if hashes: + if self.hijack.extra_generation_params.get("TI hashes"): + hashes.append(self.hijack.extra_generation_params.get("TI hashes")) + self.hijack.extra_generation_params["TI hashes"] = ", ".join(hashes) + + if any(x for x in texts if "(" in x or "[" in x) and opts.emphasis != "Original": + self.hijack.extra_generation_params["Emphasis"] = opts.emphasis + + if getattr(self, 'return_pooled', False): + return torch.hstack(zs), zs[0].pooled + elif getattr(self, 'return_masks', False): + return torch.hstack(zs), torch.hstack(tk) + else: + return torch.hstack(zs) + + def process_tokens(self, remade_batch_tokens, batch_multipliers): + """ + sends one single prompt chunk to be encoded by transformers neural network. + remade_batch_tokens is a batch of tokens - a list, where every element is a list of tokens; usually + there are exactly 77 tokens in the list. batch_multipliers is the same but for multipliers instead of tokens. + Multipliers are used to give more or less weight to the outputs of transformers network. 
Each multiplier + corresponds to one token. + """ + tokens = torch.asarray(remade_batch_tokens).to(devices.device) + + # this is for SD2: SD1 uses the same token for padding and end of text, while SD2 uses different ones. + if self.id_end != self.id_pad: + for batch_pos in range(len(remade_batch_tokens)): + index = remade_batch_tokens[batch_pos].index(self.id_end) + tokens[batch_pos, index+1:tokens.shape[1]] = self.id_pad + + z = self.encode_with_transformers(tokens) + + pooled = getattr(z, 'pooled', None) + + emphasis = sd_emphasis.get_current_option(opts.emphasis)() + emphasis.tokens = remade_batch_tokens + emphasis.multipliers = torch.asarray(batch_multipliers).to(devices.device) + emphasis.z = z + + emphasis.after_transformers() + + z = emphasis.z + + if pooled is not None: + z.pooled = pooled + + return z + + +class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + self.tokenizer = wrapped.tokenizer + + vocab = self.tokenizer.get_vocab() + + self.comma_token = vocab.get(',', None) + + self.token_mults = {} + tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k] + for text, ident in tokens_with_parens: + mult = 1.0 + for c in text: + if c == '[': + mult /= 1.1 + if c == ']': + mult *= 1.1 + if c == '(': + mult *= 1.1 + if c == ')': + mult /= 1.1 + + if mult != 1.0: + self.token_mults[ident] = mult + + self.id_start = self.wrapped.tokenizer.bos_token_id + self.id_end = self.wrapped.tokenizer.eos_token_id + self.id_pad = self.id_end + + def tokenize(self, texts): + tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] + + return tokenized + + def encode_with_transformers(self, tokens): + outputs = self.wrapped(input_ids=tokens, output_hidden_states=-opts.CLIP_stop_at_last_layers) + + if opts.CLIP_stop_at_last_layers > 1: + z = outputs.hidden_states[-opts.CLIP_stop_at_last_layers] + else: + z = outputs.last_hidden_state + z.pooled = outputs.text_embeds + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + embedding_layer = self.wrapped.text_model.embeddings + ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer.token_embedding(ids.to(embedding_layer.token_embedding.weight.device)).squeeze(0) + + return embedded + +class FrozenBertEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + self.tokenizer = wrapped.tokenizer + + vocab = self.tokenizer.get_vocab() + + self.comma_token = vocab.get(',', None) + + self.token_mults = {} + tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k] + for text, ident in tokens_with_parens: + mult = 1.0 + for c in text: + if c == '[': + mult /= 1.1 + if c == ']': + mult *= 1.1 + if c == '(': + mult *= 1.1 + if c == ')': + mult /= 1.1 + + if mult != 1.0: + self.token_mults[ident] = mult + + self.id_start = self.wrapped.tokenizer.cls_token_id + self.id_end = self.wrapped.tokenizer.sep_token_id + self.id_pad = self.wrapped.tokenizer.pad_token_id + + def empty_chunk(self): + """creates an empty PromptChunk and returns it""" + + chunk = PromptChunk() + chunk.tokens = [self.id_start] + [self.id_end] + [self.id_pad] * (self.chunk_length) + chunk.multipliers = [1.0] * (self.chunk_length + 2) + return chunk + + def 
tokenize_line(self, line): + """ + this transforms a single prompt into a list of PromptChunk objects - as many as needed to + represent the prompt. + Returns the list and the total number of tokens in the prompt. + """ + + if opts.emphasis != "None": + parsed = prompt_parser.parse_prompt_attention(line) + else: + parsed = [[line, 1.0]] + + tokenized = self.tokenize([text for text, _ in parsed]) + + chunks = [] + chunk = PromptChunk() + token_count = 0 + last_comma = -1 + + def next_chunk(is_last=False): + """puts current chunk into the list of results and produces the next one - empty; + if is_last is true, tokens tokens at the end won't add to token_count""" + nonlocal token_count + nonlocal last_comma + nonlocal chunk + + if is_last: + token_count += len(chunk.tokens) + else: + token_count += self.chunk_length + + to_add = self.chunk_length - len(chunk.tokens) + if to_add > 0: + chunk.tokens += [self.id_end] + [self.id_pad] * to_add + chunk.multipliers += [1.0] * to_add + else: + chunk.tokens += [self.id_end] + + chunk.tokens = [self.id_start] + chunk.tokens + chunk.multipliers = [1.0] + chunk.multipliers + [1.0] + + last_comma = -1 + chunks.append(chunk) + chunk = PromptChunk() + + for tokens, (text, weight) in zip(tokenized, parsed): + if text == 'BREAK' and weight == -1: + next_chunk() + continue + + position = 0 + while position < len(tokens): + token = tokens[position] + + if token == self.comma_token: + last_comma = len(chunk.tokens) + + # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack + # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next. + elif opts.comma_padding_backtrack != 0 and len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= opts.comma_padding_backtrack: + break_location = last_comma + 1 + + reloc_tokens = chunk.tokens[break_location:] + reloc_mults = chunk.multipliers[break_location:] + + chunk.tokens = chunk.tokens[:break_location] + chunk.multipliers = chunk.multipliers[:break_location] + + next_chunk() + chunk.tokens = reloc_tokens + chunk.multipliers = reloc_mults + + if len(chunk.tokens) == self.chunk_length: + next_chunk() + + embedding, embedding_length_in_tokens = self.hijack.embedding_db.find_embedding_at_position(tokens, position) + if embedding is None: + chunk.tokens.append(token) + chunk.multipliers.append(weight) + position += 1 + continue + + emb_len = int(embedding.vectors) + if len(chunk.tokens) + emb_len > self.chunk_length: + next_chunk() + + chunk.fixes.append(PromptChunkFix(len(chunk.tokens), embedding)) + + chunk.tokens += [0] * emb_len + chunk.multipliers += [weight] * emb_len + position += embedding_length_in_tokens + + if chunk.tokens or not chunks: + next_chunk(is_last=True) + + return chunks, token_count + + def tokenize(self, texts): + tokenized = self.wrapped.tokenizer(texts, truncation=False, add_special_tokens=False)["input_ids"] + + return tokenized + + def encode_with_transformers(self, tokens): + attn_mask = [] + for token in tokens: + temp_mask = [] + for token_id in token: + if token_id != self.id_pad: + temp_mask.append(1) + else: + temp_mask.append(0) + attn_mask.append(temp_mask) + outputs = self.wrapped(input_ids=tokens,attention_mask=torch.tensor(attn_mask).to(devices.device),output_hidden_states=-opts.CLIP_stop_at_last_layers) + + if opts.CLIP_stop_at_last_layers > 1: + z = 
outputs.hidden_states[-opts.CLIP_stop_at_last_layers] + else: + z = outputs.last_hidden_state + + return z + + def encode_embedding_init_text(self, init_text, nvpt): + embedding_layer = self.wrapped.text_model.embeddings + ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"] + embedded = embedding_layer.token_embedding.wrapped(ids.to(embedding_layer.token_embedding.wrapped.weight.device)).squeeze(0) + + return embedded + + def forward(self, texts): + """ + Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. + Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will + be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280. + An example shape returned by this function can be: (2, 77, 768). + For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values. + Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet + is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" + """ + + batch_chunks, token_count = self.process_texts(texts) + + chunk_count = max([len(x) for x in batch_chunks]) + + zs = [] + tk = [] + for i in range(chunk_count): + batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks] + + tokens = [x.tokens for x in batch_chunk] + attn_mask = [] + for token in tokens: + temp_mask = [] + for token_id in token: + if token_id != self.id_pad: + temp_mask.append(1) + else: + temp_mask.append(0) + attn_mask.append(temp_mask) + multipliers = [x.multipliers for x in batch_chunk] + + z = self.process_tokens(tokens, multipliers) + zs.append(z) + tk.append(torch.tensor(attn_mask)) + + if getattr(self, 'return_masks', False): + return torch.hstack([zs[0]]), torch.hstack([tk[0]]).to(devices.device) + else: + return torch.hstack([zs[0]]) + +class FrozenT5EmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWords): + def __init__(self, wrapped, hijack): + super().__init__(wrapped, hijack) + + self.tokenizer = wrapped.tokenizer + self.chunk_length = 255 + self.id_start = self.tokenizer.bos_token_id + self.id_end = self.tokenizer.eos_token_id + self.id_pad = 0 + + def empty_chunk(self): + """creates an empty PromptChunk and returns it""" + + chunk = PromptChunk() + chunk.tokens = [self.id_end] + [self.id_pad] * self.chunk_length + chunk.multipliers = [1.0] + [1.0]* self.chunk_length + return chunk + + def tokenize_line(self, line): + """ + this transforms a single prompt into a list of PromptChunk objects - as many as needed to + represent the prompt. + Returns the list and the total number of tokens in the prompt. 
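+        For the T5 encoder the chunk length is 255 prompt tokens; chunks are closed with the eos token and padded with the pad token instead of being wrapped in bos/eos like the CLIP chunks.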
+ """ + + parsed = prompt_parser.parse_prompt_attention(line) + + tokenized = self.tokenize([text for text, _ in parsed]) + + chunks = [] + chunk = PromptChunk() + token_count = 0 + last_comma = -1 + + def next_chunk(is_last=False): + """puts current chunk into the list of results and produces the next one - empty; + if is_last is true, tokens tokens at the end won't add to token_count""" + nonlocal token_count + nonlocal last_comma + nonlocal chunk + + if is_last: + token_count += len(chunk.tokens) + else: + token_count += self.chunk_length + + to_add = self.chunk_length - len(chunk.tokens) + if to_add > 0: + chunk.tokens += [self.id_end] + [self.id_pad] * to_add + chunk.multipliers += [1.0] * to_add + else: + chunk.tokens += [self.id_end] + + chunk.tokens = [] + chunk.tokens + chunk.multipliers = [] + chunk.multipliers + [1.0] + + last_comma = -1 + chunks.append(chunk) + chunk = PromptChunk() + + for tokens, (text, weight) in zip(tokenized, parsed): + if text == 'BREAK' and weight == -1: + next_chunk() + continue + + position = 0 + while position < len(tokens): + token = tokens[position] + + if token == self.comma_token: + last_comma = len(chunk.tokens) + + # this is when we are at the end of alloted 75 tokens for the current chunk, and the current token is not a comma. opts.comma_padding_backtrack + # is a setting that specifies that if there is a comma nearby, the text after the comma should be moved out of this chunk and into the next. + elif len(chunk.tokens) == self.chunk_length and last_comma != -1 and len(chunk.tokens) - last_comma <= 20: + break_location = last_comma + 1 + + reloc_tokens = chunk.tokens[break_location:] + reloc_mults = chunk.multipliers[break_location:] + + chunk.tokens = chunk.tokens[:break_location] + chunk.multipliers = chunk.multipliers[:break_location] + + next_chunk() + chunk.tokens = reloc_tokens + chunk.multipliers = reloc_mults + + if len(chunk.tokens) == self.chunk_length: + next_chunk() + + chunk.tokens.append(token) + chunk.multipliers.append(weight) + position += 1 + + if chunk.tokens or not chunks: + next_chunk(is_last=True) + + return chunks, token_count + + def encode_with_transformers(self, tokens): + attn_mask = [] + for token in tokens: + temp_mask = [] + for token_id in token: + if token_id != self.id_pad: + temp_mask.append(1) + else: + temp_mask.append(0) + attn_mask.append(temp_mask) + outputs = self.wrapped(input_ids=tokens, attention_mask=torch.tensor(attn_mask).to(devices.device), output_hidden_states=True) + + ''' + if self.wrapped.layer == "last": + z = outputs.last_hidden_state + else: + z = outputs.hidden_states[self.wrapped.layer_idx] + ''' + z = outputs.last_hidden_state + return z + + def forward(self, texts): + """ + Accepts an array of texts; Passes texts through transformers network to create a tensor with numerical representation of those texts. + Returns a tensor with shape of (B, T, C), where B is length of the array; T is length, in tokens, of texts (including padding) - T will + be a multiple of 77; and C is dimensionality of each token - for SD1 it's 768, for SD2 it's 1024, and for SDXL it's 1280. + An example shape returned by this function can be: (2, 77, 768). + For SDXL, instead of returning one tensor avobe, it returns a tuple with two: the other one with shape (B, 1280) with pooled values. 
+ Webui usually sends just one text at a time through this function - the only time when texts is an array with more than one elemenet + is when you do prompt editing: "a picture of a [cat:dog:0.4] eating ice cream" + """ + + batch_chunks, token_count = self.process_texts(texts) + + chunk_count = max([len(x) for x in batch_chunks]) + + zs = [] + tk = [] + for i in range(chunk_count): + batch_chunk = [chunks[i] if i < len(chunks) else self.empty_chunk() for chunks in batch_chunks] + + tokens = [x.tokens for x in batch_chunk] + attn_mask = [] + for token in tokens: + temp_mask = [] + for token_id in token: + if token_id != self.id_pad: + temp_mask.append(1) + else: + temp_mask.append(0) + attn_mask.append(temp_mask) + multipliers = [x.multipliers for x in batch_chunk] + + z = self.process_tokens(tokens, multipliers) + zs.append(z) + tk.append(torch.tensor(attn_mask)) + + if getattr(self, 'return_masks', False): + return torch.hstack([zs[0]]), torch.hstack([tk[0]]).to(devices.device) + else: + return torch.hstack([zs[0]]) \ No newline at end of file diff --git a/sd-webui-hunyuan-dit/hunyuan_utils/utils.py b/sd-webui-hunyuan-dit/hunyuan_utils/utils.py new file mode 100755 index 0000000..2d9fdbc --- /dev/null +++ b/sd-webui-hunyuan-dit/hunyuan_utils/utils.py @@ -0,0 +1,533 @@ +from modules import devices, rng, shared +import numpy as np +import gc +import inspect +import torch +from typing import Any, Dict, List, Optional, Union, Tuple +from diffusers.schedulers import ( + DDIMScheduler, + DDPMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + KDPM2DiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + UniPCMultistepScheduler, + LCMScheduler, +) + +hunyuan_transformer_config_v12 = { + "_class_name": "HunyuanDiT2DModel", + "_diffusers_version": "0.30.0.dev0", + "activation_fn": "gelu-approximate", + "attention_head_dim": 88, + "cross_attention_dim": 1024, + "cross_attention_dim_t5": 2048, + "hidden_size": 1408, + "in_channels": 4, + "learn_sigma": True, + "mlp_ratio": 4.3637, + "norm_type": "layer_norm", + "num_attention_heads": 16, + "num_layers": 40, + "patch_size": 2, + "pooled_projection_dim": 1024, + "sample_size": 128, + "text_len": 77, + "text_len_t5": 256, + "use_style_cond_and_image_meta_size": False +} + +dit_sampler_dict = { + "Euler a":EulerAncestralDiscreteScheduler(), + "Euler":EulerDiscreteScheduler(), + "LMS":LMSDiscreteScheduler(), + "Heun":HeunDiscreteScheduler(), + "DPM2":KDPM2DiscreteScheduler(), + "DPM2 a":KDPM2AncestralDiscreteScheduler(), + "DPM++ SDE":DPMSolverSinglestepScheduler(), + "DPM++ 2M":DPMSolverMultistepScheduler(), + "DPM++ 2S a":DPMSolverSinglestepScheduler(), + "LMS Karras":LMSDiscreteScheduler(use_karras_sigmas=True), + "DPM2 Karras":KDPM2DiscreteScheduler(use_karras_sigmas=True), + "DPM2 a Karras":KDPM2AncestralDiscreteScheduler(use_karras_sigmas=True), + "DPM++ SDE Karras":DPMSolverSinglestepScheduler(use_karras_sigmas=True), + "DPM++ 2M Karras":DPMSolverMultistepScheduler(use_karras_sigmas=True), + "DPM++ 2S a Karras":DPMSolverSinglestepScheduler(use_karras_sigmas=True), + "DDIM":DDIMScheduler(), + "UniPC":UniPCMultistepScheduler(), + "DPM++ 2M SDE Karras":DPMSolverMultistepScheduler(use_karras_sigmas=True,algorithm_type="sde-dpmsolver++"), + "DPM++ 2M SDE":DPMSolverMultistepScheduler(algorithm_type="sde-dpmsolver++"), + "LCM":LCMScheduler() +} + +def get_resize_crop_region_for_grid(src, tgt_size): + th = tw = tgt_size + 
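+    # src is the latent patch grid as (height, width); the code below scales it to fit a tgt_size x tgt_size square
+    # and returns the centered crop region that is later passed to get_2d_rotary_pos_embed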
h, w = src + + r = h / w + + # resize + if r > 1: + resize_height = th + resize_width = int(round(th / h * w)) + else: + resize_width = tw + resize_height = int(round(tw / w * h)) + + crop_top = int(round((th - resize_height) / 2.0)) + crop_left = int(round((tw - resize_width) / 2.0)) + + return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width) + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + +def unload_model(current_model): + if current_model is not None: + current_model.to(devices.cpu) + current_model = None + gc.collect() + devices.torch_gc() + return current_model + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs +def prepare_extra_step_kwargs(scheduler, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + +def randn_tensor( + shape: Union[Tuple, List], + generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, + device: Optional["torch.device"] = None, + dtype: Optional["torch.dtype"] = None, + layout: Optional["torch.layout"] = None, +): + """A helper function to create random tensors on the desired `device` with the desired `dtype`. When + passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor + is always created on the CPU. + """ + # device on which tensor is created defaults to device + rand_device = device + batch_size = shape[0] + + layout = layout or torch.strided + device = device or torch.device("cpu") + + if generator is not None: + gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type + if gen_device_type != device.type and gen_device_type == "cpu": + rand_device = "cpu" + if device != "mps": + print( + f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." + f" Tensors will be created on 'cpu' and then moved to {device}. 
Note that one can probably" + f" slighly speed up this function by passing a generator that was created on the {device} device." + ) + elif gen_device_type != device.type and gen_device_type == "cuda": + raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") + + # make sure generator list of length 1 is treated like a non-list + if isinstance(generator, list) and len(generator) == 1: + generator = generator[0] + + if isinstance(generator, list): + shape = (1,) + shape[1:] + latents = [ + torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) + for i in range(batch_size) + ] + latents = torch.cat(latents, dim=0).to(device) + else: + latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) + + return latents + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents +def prepare_latents_txt2img(vae_scale_factor, scheduler, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = ( + batch_size, + num_channels_latents, + int(height) // vae_scale_factor, + int(width) // vae_scale_factor, + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + if hasattr(scheduler, 'init_noise_sigma'): + latents = latents * scheduler.init_noise_sigma + return latents + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + sigmas: Optional[List[float]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps` + must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed, + `num_inference_steps` and `sigmas` must be `None`. + sigmas (`List[float]`, *optional*): + Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed, + `num_inference_steps` and `timesteps` must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None and sigmas is not None: + raise ValueError("Only one of `timesteps` or `sigmas` can be passed. 
Please choose one to set custom values") + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + elif sigmas is not None: + accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accept_sigmas: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" sigmas schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +# Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.get_timesteps +def get_timesteps(scheduler, num_inference_steps, strength, device, denoising_start=None): + # get the original timestep using init_timestep + if denoising_start is None: + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + else: + t_start = 0 + + timesteps = scheduler.timesteps[t_start * scheduler.order :] + + # Strength is irrelevant if we directly request a timestep to start at; + # that is, strength is determined by the denoising_start instead. + if denoising_start is not None: + discrete_timestep_cutoff = int( + round( + scheduler.config.num_train_timesteps + - (denoising_start * scheduler.config.num_train_timesteps) + ) + ) + + num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item() + if scheduler.order == 2 and num_inference_steps % 2 == 0: + # if the scheduler is a 2nd order scheduler we might have to do +1 + # because `num_inference_steps` might be even given that every timestep + # (except the highest one) is duplicated. If `num_inference_steps` is even it would + # mean that we cut the timesteps in the middle of the denoising step + # (between 1st and 2nd derivative) which leads to incorrect results. 
By adding 1 + # we ensure that the denoising process always ends after the 2nd derivate step of the scheduler + num_inference_steps = num_inference_steps + 1 + + # because t_n+1 >= t_n, we slice the timesteps starting from the end + timesteps = timesteps[-num_inference_steps:] + return timesteps, num_inference_steps + + return timesteps, num_inference_steps - t_start + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents( + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" +): + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": + return encoder_output.latent_dist.mode() + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + +def _encode_vae_image(image: torch.Tensor, generator: torch.Generator): + #dtype = image.dtype + #image = image.float() + #self.vae_model.to(dtype=torch.float32) + + if isinstance(generator, list): + image_latents = [ + retrieve_latents(shared.vae_model.encode(image[i : i + 1]), generator=generator[i]) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(shared.vae_model.encode(image), generator=generator) + + #self.vae_model.to(dtype) + + #image_latents = image_latents.to(dtype) + image_latents = shared.vae_model.config.scaling_factor * image_latents + + return image_latents + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents +def prepare_latents_img2img(vae_scale_factor, scheduler, image, batch_size, num_channels_latents, height, width, dtype, device, generator, seeds, timestep): + shape = ( + batch_size, + num_channels_latents, + int(height) // vae_scale_factor, + int(width) // vae_scale_factor, + ) + generators = [rng.create_generator(seed) for seed in seeds] + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + image_latents = _encode_vae_image(image, generator=generators) + image_latents = image_latents.repeat(batch_size // image_latents.shape[0], 1, 1, 1) + + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype).to(devices.device) + init_latents = scheduler.add_noise(image_latents, noise, timestep) + latents = init_latents.to(device=devices.device, dtype=dtype) + + return latents, noise, image_latents + +def guess_dit_model(state_dict): + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + if "mlp_t5.0.weight" in state_dict: + return "hunyuan-original" + elif "text_embedder.linear_1.weight" in state_dict: + return "hunyuan" + else: + return "non supported dit" + +def convert_hunyuan_to_diffusers(state_dict): + if "state_dict" in state_dict: + state_dict = state_dict["state_dict"] + # input_size -> sample_size, text_dim -> cross_attention_dim + num_layers = 40 + for i in range(num_layers): + # attn1 + # Wkqv -> to_q, to_k, to_v + q, k, v = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.weight"], 3, dim=0) + q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn1.Wqkv.bias"], 3, dim=0) + state_dict[f"blocks.{i}.attn1.to_q.weight"] = q + state_dict[f"blocks.{i}.attn1.to_q.bias"] = q_bias + state_dict[f"blocks.{i}.attn1.to_k.weight"] = k + state_dict[f"blocks.{i}.attn1.to_k.bias"] = k_bias + state_dict[f"blocks.{i}.attn1.to_v.weight"] = v + state_dict[f"blocks.{i}.attn1.to_v.bias"] = v_bias + state_dict.pop(f"blocks.{i}.attn1.Wqkv.weight") + state_dict.pop(f"blocks.{i}.attn1.Wqkv.bias") + + # q_norm, k_norm -> norm_q, norm_k + state_dict[f"blocks.{i}.attn1.norm_q.weight"] = state_dict[f"blocks.{i}.attn1.q_norm.weight"] + state_dict[f"blocks.{i}.attn1.norm_q.bias"] = state_dict[f"blocks.{i}.attn1.q_norm.bias"] + state_dict[f"blocks.{i}.attn1.norm_k.weight"] = state_dict[f"blocks.{i}.attn1.k_norm.weight"] + state_dict[f"blocks.{i}.attn1.norm_k.bias"] = state_dict[f"blocks.{i}.attn1.k_norm.bias"] + + state_dict.pop(f"blocks.{i}.attn1.q_norm.weight") + state_dict.pop(f"blocks.{i}.attn1.q_norm.bias") + state_dict.pop(f"blocks.{i}.attn1.k_norm.weight") + state_dict.pop(f"blocks.{i}.attn1.k_norm.bias") + + # out_proj -> to_out + state_dict[f"blocks.{i}.attn1.to_out.0.weight"] = state_dict[f"blocks.{i}.attn1.out_proj.weight"] + state_dict[f"blocks.{i}.attn1.to_out.0.bias"] = state_dict[f"blocks.{i}.attn1.out_proj.bias"] + state_dict.pop(f"blocks.{i}.attn1.out_proj.weight") + state_dict.pop(f"blocks.{i}.attn1.out_proj.bias") + + # attn2 + # kq_proj -> to_k, to_v + k, v = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.weight"], 2, dim=0) + k_bias, v_bias = torch.chunk(state_dict[f"blocks.{i}.attn2.kv_proj.bias"], 2, dim=0) + state_dict[f"blocks.{i}.attn2.to_k.weight"] = k + state_dict[f"blocks.{i}.attn2.to_k.bias"] = k_bias + state_dict[f"blocks.{i}.attn2.to_v.weight"] = v + state_dict[f"blocks.{i}.attn2.to_v.bias"] = v_bias + state_dict.pop(f"blocks.{i}.attn2.kv_proj.weight") + state_dict.pop(f"blocks.{i}.attn2.kv_proj.bias") + + # q_proj -> to_q + state_dict[f"blocks.{i}.attn2.to_q.weight"] = state_dict[f"blocks.{i}.attn2.q_proj.weight"] + state_dict[f"blocks.{i}.attn2.to_q.bias"] = state_dict[f"blocks.{i}.attn2.q_proj.bias"] + state_dict.pop(f"blocks.{i}.attn2.q_proj.weight") + state_dict.pop(f"blocks.{i}.attn2.q_proj.bias") + + # q_norm, k_norm -> norm_q, norm_k + state_dict[f"blocks.{i}.attn2.norm_q.weight"] = state_dict[f"blocks.{i}.attn2.q_norm.weight"] + state_dict[f"blocks.{i}.attn2.norm_q.bias"] = 
state_dict[f"blocks.{i}.attn2.q_norm.bias"] + state_dict[f"blocks.{i}.attn2.norm_k.weight"] = state_dict[f"blocks.{i}.attn2.k_norm.weight"] + state_dict[f"blocks.{i}.attn2.norm_k.bias"] = state_dict[f"blocks.{i}.attn2.k_norm.bias"] + + state_dict.pop(f"blocks.{i}.attn2.q_norm.weight") + state_dict.pop(f"blocks.{i}.attn2.q_norm.bias") + state_dict.pop(f"blocks.{i}.attn2.k_norm.weight") + state_dict.pop(f"blocks.{i}.attn2.k_norm.bias") + + # out_proj -> to_out + state_dict[f"blocks.{i}.attn2.to_out.0.weight"] = state_dict[f"blocks.{i}.attn2.out_proj.weight"] + state_dict[f"blocks.{i}.attn2.to_out.0.bias"] = state_dict[f"blocks.{i}.attn2.out_proj.bias"] + state_dict.pop(f"blocks.{i}.attn2.out_proj.weight") + state_dict.pop(f"blocks.{i}.attn2.out_proj.bias") + + # switch norm 2 and norm 3 + norm2_weight = state_dict[f"blocks.{i}.norm2.weight"] + norm2_bias = state_dict[f"blocks.{i}.norm2.bias"] + state_dict[f"blocks.{i}.norm2.weight"] = state_dict[f"blocks.{i}.norm3.weight"] + state_dict[f"blocks.{i}.norm2.bias"] = state_dict[f"blocks.{i}.norm3.bias"] + state_dict[f"blocks.{i}.norm3.weight"] = norm2_weight + state_dict[f"blocks.{i}.norm3.bias"] = norm2_bias + + # norm1 -> norm1.norm + # default_modulation.1 -> norm1.linear + state_dict[f"blocks.{i}.norm1.norm.weight"] = state_dict[f"blocks.{i}.norm1.weight"] + state_dict[f"blocks.{i}.norm1.norm.bias"] = state_dict[f"blocks.{i}.norm1.bias"] + state_dict[f"blocks.{i}.norm1.linear.weight"] = state_dict[f"blocks.{i}.default_modulation.1.weight"] + state_dict[f"blocks.{i}.norm1.linear.bias"] = state_dict[f"blocks.{i}.default_modulation.1.bias"] + state_dict.pop(f"blocks.{i}.norm1.weight") + state_dict.pop(f"blocks.{i}.norm1.bias") + state_dict.pop(f"blocks.{i}.default_modulation.1.weight") + state_dict.pop(f"blocks.{i}.default_modulation.1.bias") + + # mlp.fc1 -> ff.net.0, mlp.fc2 -> ff.net.2 + state_dict[f"blocks.{i}.ff.net.0.proj.weight"] = state_dict[f"blocks.{i}.mlp.fc1.weight"] + state_dict[f"blocks.{i}.ff.net.0.proj.bias"] = state_dict[f"blocks.{i}.mlp.fc1.bias"] + state_dict[f"blocks.{i}.ff.net.2.weight"] = state_dict[f"blocks.{i}.mlp.fc2.weight"] + state_dict[f"blocks.{i}.ff.net.2.bias"] = state_dict[f"blocks.{i}.mlp.fc2.bias"] + state_dict.pop(f"blocks.{i}.mlp.fc1.weight") + state_dict.pop(f"blocks.{i}.mlp.fc1.bias") + state_dict.pop(f"blocks.{i}.mlp.fc2.weight") + state_dict.pop(f"blocks.{i}.mlp.fc2.bias") + + # pooler -> time_extra_emb + state_dict["time_extra_emb.pooler.positional_embedding"] = state_dict["pooler.positional_embedding"] + state_dict["time_extra_emb.pooler.k_proj.weight"] = state_dict["pooler.k_proj.weight"] + state_dict["time_extra_emb.pooler.k_proj.bias"] = state_dict["pooler.k_proj.bias"] + state_dict["time_extra_emb.pooler.q_proj.weight"] = state_dict["pooler.q_proj.weight"] + state_dict["time_extra_emb.pooler.q_proj.bias"] = state_dict["pooler.q_proj.bias"] + state_dict["time_extra_emb.pooler.v_proj.weight"] = state_dict["pooler.v_proj.weight"] + state_dict["time_extra_emb.pooler.v_proj.bias"] = state_dict["pooler.v_proj.bias"] + state_dict["time_extra_emb.pooler.c_proj.weight"] = state_dict["pooler.c_proj.weight"] + state_dict["time_extra_emb.pooler.c_proj.bias"] = state_dict["pooler.c_proj.bias"] + state_dict.pop("pooler.k_proj.weight") + state_dict.pop("pooler.k_proj.bias") + state_dict.pop("pooler.q_proj.weight") + state_dict.pop("pooler.q_proj.bias") + state_dict.pop("pooler.v_proj.weight") + state_dict.pop("pooler.v_proj.bias") + state_dict.pop("pooler.c_proj.weight") + state_dict.pop("pooler.c_proj.bias") + 
state_dict.pop("pooler.positional_embedding") + + # t_embedder -> time_embedding (`TimestepEmbedding`) + state_dict["time_extra_emb.timestep_embedder.linear_1.bias"] = state_dict["t_embedder.mlp.0.bias"] + state_dict["time_extra_emb.timestep_embedder.linear_1.weight"] = state_dict["t_embedder.mlp.0.weight"] + state_dict["time_extra_emb.timestep_embedder.linear_2.bias"] = state_dict["t_embedder.mlp.2.bias"] + state_dict["time_extra_emb.timestep_embedder.linear_2.weight"] = state_dict["t_embedder.mlp.2.weight"] + + state_dict.pop("t_embedder.mlp.0.bias") + state_dict.pop("t_embedder.mlp.0.weight") + state_dict.pop("t_embedder.mlp.2.bias") + state_dict.pop("t_embedder.mlp.2.weight") + + # x_embedder -> pos_embd (`PatchEmbed`) + state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"] + state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"] + state_dict.pop("x_embedder.proj.weight") + state_dict.pop("x_embedder.proj.bias") + + # mlp_t5 -> text_embedder + state_dict["text_embedder.linear_1.bias"] = state_dict["mlp_t5.0.bias"] + state_dict["text_embedder.linear_1.weight"] = state_dict["mlp_t5.0.weight"] + state_dict["text_embedder.linear_2.bias"] = state_dict["mlp_t5.2.bias"] + state_dict["text_embedder.linear_2.weight"] = state_dict["mlp_t5.2.weight"] + state_dict.pop("mlp_t5.0.bias") + state_dict.pop("mlp_t5.0.weight") + state_dict.pop("mlp_t5.2.bias") + state_dict.pop("mlp_t5.2.weight") + + # extra_embedder -> extra_embedder + state_dict["time_extra_emb.extra_embedder.linear_1.bias"] = state_dict["extra_embedder.0.bias"] + state_dict["time_extra_emb.extra_embedder.linear_1.weight"] = state_dict["extra_embedder.0.weight"] + state_dict["time_extra_emb.extra_embedder.linear_2.bias"] = state_dict["extra_embedder.2.bias"] + state_dict["time_extra_emb.extra_embedder.linear_2.weight"] = state_dict["extra_embedder.2.weight"] + state_dict.pop("extra_embedder.0.bias") + state_dict.pop("extra_embedder.0.weight") + state_dict.pop("extra_embedder.2.bias") + state_dict.pop("extra_embedder.2.weight") + + # model.final_adaLN_modulation.1 -> norm_out.linear + def swap_scale_shift(weight): + shift, scale = weight.chunk(2, dim=0) + new_weight = torch.cat([scale, shift], dim=0) + return new_weight + + state_dict["norm_out.linear.weight"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.weight"]) + state_dict["norm_out.linear.bias"] = swap_scale_shift(state_dict["final_layer.adaLN_modulation.1.bias"]) + state_dict.pop("final_layer.adaLN_modulation.1.weight") + state_dict.pop("final_layer.adaLN_modulation.1.bias") + + # final_linear -> proj_out + state_dict["proj_out.weight"] = state_dict["final_layer.linear.weight"] + state_dict["proj_out.bias"] = state_dict["final_layer.linear.bias"] + state_dict.pop("final_layer.linear.weight") + state_dict.pop("final_layer.linear.bias") + return state_dict \ No newline at end of file diff --git a/sd-webui-hunyuan-dit/requirements.txt b/sd-webui-hunyuan-dit/requirements.txt new file mode 100755 index 0000000..318434f --- /dev/null +++ b/sd-webui-hunyuan-dit/requirements.txt @@ -0,0 +1,2 @@ +transformers==4.40.1 +git+https://github.com/huggingface/diffusers.git \ No newline at end of file diff --git a/sd-webui-hunyuan-dit/scripts/hunyuandit.py b/sd-webui-hunyuan-dit/scripts/hunyuandit.py new file mode 100755 index 0000000..8313468 --- /dev/null +++ b/sd-webui-hunyuan-dit/scripts/hunyuandit.py @@ -0,0 +1,883 @@ + +import torch +import gradio as gr +from transformers import T5EncoderModel, MT5Tokenizer, BertModel, BertTokenizer +from 
diffusers import AutoencoderKL, DDPMScheduler +from modules import prompt_parser, shared, rng, devices, processing, scripts, masking, sd_models, sd_samplers_common, images, paths, face_restoration, script_callbacks +from modules.sd_hijack import model_hijack +from modules.timer import Timer +from hunyuan_utils.utils import dit_sampler_dict, hunyuan_transformer_config_v12, retrieve_timesteps, get_timesteps, get_resize_crop_region_for_grid, unload_model, prepare_extra_step_kwargs, prepare_latents_txt2img, prepare_latents_img2img, guess_dit_model, convert_hunyuan_to_diffusers +from hunyuan_utils import sd_hijack_clip_diffusers, diffusers_learned_conditioning +import os +import numpy as np +from PIL import Image, ImageOps +import cv2 +import hashlib + +shared.clip_l_model = None +shared.mt5_model = None +shared.vae_model = None + +def sample_txt2img(self, conditioning, unconditional_conditioning, seeds): + # define sampler"" + self.sampler = dit_sampler_dict.get((self.sampler_name+" "+self.scheduler.replace("Automatic","")).strip(),DDPMScheduler()).from_pretrained(shared.opts.Hunyuan_model_path,subfolder="scheduler") + # reuse webui generated conditionings + _, tensor = prompt_parser.reconstruct_multicond_batch(conditioning, 0) + prompt_embeds = tensor["crossattn"] + prompt_attention_mask = tensor["mask"] + prompt_embeds_2 = tensor["crossattn_2"] + prompt_attention_mask_2 = tensor["mask_2"] + uncond = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, 0) + negative_prompt_embeds = uncond["crossattn"] + negative_prompt_attention_mask = uncond["mask"] + negative_prompt_embeds_2 = uncond["crossattn_2"] + negative_prompt_attention_mask_2 = uncond["mask_2"] + # 4. Prepare timesteps + self.sampler.set_timesteps(self.steps, device=devices.device) + timesteps = self.sampler.timesteps + shared.state.sampling_steps = len(timesteps) + # 5. Prepare latents. + latent_channels = self.sd_model.config.in_channels + generators = [rng.create_generator(seed) for seed in seeds] + latents = prepare_latents_txt2img( + 2 ** (len(shared.vae_model.config.block_out_channels) - 1), + self.sampler, + self.batch_size, + latent_channels, + self.height, + self.width, + prompt_embeds.dtype, + torch.device("cuda") if shared.opts.randn_source == "GPU" else torch.device("cpu"), + generators, + None + ).to(devices.device) + # 6. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = prepare_extra_step_kwargs(self.sampler, generators, 0.0) + + # 7 create image_rotary_emb, style embedding & time ids + grid_height = self.height // 8 // self.sd_model.config.patch_size + grid_width = self.width // 8 // self.sd_model.config.patch_size + base_size = 512 // 8 // self.sd_model.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + from diffusers.models.embeddings import get_2d_rotary_pos_embed + image_rotary_emb = get_2d_rotary_pos_embed( + self.sd_model.inner_dim // self.sd_model.num_heads, grid_crops_coords, (grid_height, grid_width) + ) + style = torch.tensor([0], device=devices.device) + + target_size = (self.height, self.width) + add_time_ids = list((1024, 1024) + target_size + (0,0)) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + if self.cfg_scale > 1: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=devices.device).repeat( + self.batch_size, 1 + ) + style = style.to(device=devices.device).repeat(self.batch_size) + for i, t in enumerate(timesteps): + if shared.state.interrupted or shared.state.skipped: + raise sd_samplers_common.InterruptedException + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.cfg_scale > 1.0 else latents + latent_model_input = self.sampler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=devices.device).to( + dtype=latent_model_input.dtype + ) + # predict the noise residual + noise_pred = self.sd_model( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.cfg_scale > 1.0: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.sampler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + # update process + shared.state.sampling_step += 1 + shared.total_tqdm.update() + return latents.to(devices.dtype) + +def sample_img2img(self, conditioning, unconditional_conditioning, seeds): + # define sampler + self.sampler = dit_sampler_dict.get((self.sampler_name+" "+self.scheduler.replace("Automatic","")).strip(),DDPMScheduler()).from_pretrained(shared.opts.Hunyuan_model_path,subfolder="scheduler") + # reuse webui generated conditionings + _, tensor = prompt_parser.reconstruct_multicond_batch(conditioning, 0) + prompt_embeds = tensor["crossattn"] + prompt_attention_mask = tensor["mask"] + prompt_embeds_2 = 
tensor["crossattn_2"] + prompt_attention_mask_2 = tensor["mask_2"] + uncond = prompt_parser.reconstruct_cond_batch(unconditional_conditioning, 0) + negative_prompt_embeds = uncond["crossattn"] + negative_prompt_attention_mask = uncond["mask"] + negative_prompt_embeds_2 = uncond["crossattn_2"] + negative_prompt_attention_mask_2 = uncond["mask_2"] + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps( + self.sampler, self.steps, devices.device, None, None + ) + timesteps, num_inference_steps = get_timesteps( + self.sampler, + num_inference_steps, + self.denoising_strength, + devices.device, + denoising_start=None, + ) + latent_timestep = timesteps[:1].repeat(self.batch_size) + shared.state.sampling_steps = len(timesteps) + # 5. Prepare latents. + latent_channels = self.sd_model.config.in_channels + latents_outputs = prepare_latents_img2img( + 2 ** (len(shared.vae_model.config.block_out_channels) - 1), + self.sampler, + self.image, + self.batch_size, + latent_channels, + self.height, + self.width, + prompt_embeds.dtype, + torch.device("cuda") if shared.opts.randn_source == "GPU" else torch.device("cpu"), + None, + seeds, + latent_timestep + ) + latents, noise, image_latents = latents_outputs + self.init_latent = latents + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = prepare_extra_step_kwargs(self.sampler, None, 0.0) + + # 7 create image_rotary_emb, style embedding & time ids + grid_height = self.height // 8 // self.sd_model.config.patch_size + grid_width = self.width // 8 // self.sd_model.config.patch_size + base_size = 512 // 8 // self.sd_model.config.patch_size + grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size) + from diffusers.models.embeddings import get_2d_rotary_pos_embed + image_rotary_emb = get_2d_rotary_pos_embed( + self.sd_model.inner_dim // self.sd_model.num_heads, grid_crops_coords, (grid_height, grid_width) + ) + style = torch.tensor([0], device=devices.device) + + target_size = (self.height, self.width) + add_time_ids = list((1024, 1024) + target_size + (0,0)) + add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype) + if self.cfg_scale > 1: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask]) + prompt_embeds_2 = torch.cat([negative_prompt_embeds_2, prompt_embeds_2]) + prompt_attention_mask_2 = torch.cat([negative_prompt_attention_mask_2, prompt_attention_mask_2]) + add_time_ids = torch.cat([add_time_ids] * 2, dim=0) + style = torch.cat([style] * 2, dim=0) + add_time_ids = add_time_ids.to(dtype=prompt_embeds.dtype, device=devices.device).repeat( + self.batch_size, 1 + ) + style = style.to(device=devices.device).repeat(self.batch_size) + for i, t in enumerate(timesteps): + if shared.state.interrupted or shared.state.skipped: + raise sd_samplers_common.InterruptedException + # expand the latents if we are doing classifier free guidance + latent_model_input = latents + latent_model_input = self.sampler.scale_model_input(latent_model_input, t) + + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=devices.device).to( + dtype=latent_model_input.dtype + ) + + # predict the noise residual + noise_pred = self.sd_model( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + 
encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] + + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if self.cfg_scale > 1.0: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.cfg_scale * (noise_pred_text - noise_pred_uncond) + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.sampler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + if latent_channels == 4 and self.image_mask is not None: + latents = self.mask * self.init_latent + self.nmask * latents + # update process + shared.state.sampling_step += 1 + shared.total_tqdm.update() + + return latents.to(devices.dtype) + +def init_img2img(self, all_prompts, all_seeds, all_subseeds): + self.extra_generation_params["Denoising strength"] = self.denoising_strength + + self.image_cfg_scale: float = self.image_cfg_scale if shared.sd_model.cond_stage_key == "edit" else None + + #self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model) + crop_region = None + + image_mask = self.image_mask + + if image_mask is not None: + # image_mask is passed in as RGBA by Gradio to support alpha masks, + # but we still want to support binary masks. + image_mask = processing.create_binary_mask(image_mask, round=self.mask_round) + + if self.inpainting_mask_invert: + image_mask = ImageOps.invert(image_mask) + self.extra_generation_params["Mask mode"] = "Inpaint not masked" + + if self.mask_blur_x > 0: + np_mask = np.array(image_mask) + kernel_size = 2 * int(2.5 * self.mask_blur_x + 0.5) + 1 + np_mask = cv2.GaussianBlur(np_mask, (kernel_size, 1), self.mask_blur_x) + image_mask = Image.fromarray(np_mask) + + if self.mask_blur_y > 0: + np_mask = np.array(image_mask) + kernel_size = 2 * int(2.5 * self.mask_blur_y + 0.5) + 1 + np_mask = cv2.GaussianBlur(np_mask, (1, kernel_size), self.mask_blur_y) + image_mask = Image.fromarray(np_mask) + + if self.mask_blur_x > 0 or self.mask_blur_y > 0: + self.extra_generation_params["Mask blur"] = self.mask_blur + + if self.inpaint_full_res: + self.mask_for_overlay = image_mask + mask = image_mask.convert('L') + crop_region = masking.get_crop_region_v2(mask, self.inpaint_full_res_padding) + if crop_region: + crop_region = masking.expand_crop_region(crop_region, self.width, self.height, mask.width, mask.height) + x1, y1, x2, y2 = crop_region + mask = mask.crop(crop_region) + image_mask = images.resize_image(2, mask, self.width, self.height) + self.paste_to = (x1, y1, x2-x1, y2-y1) + self.extra_generation_params["Inpaint area"] = "Only masked" + self.extra_generation_params["Masked area padding"] = self.inpaint_full_res_padding + else: + crop_region = None + image_mask = None + self.mask_for_overlay = None + self.inpaint_full_res = False + massage = 'Unable to perform "Inpaint Only mask" because mask is blank, switch to img2img mode.' 
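+            # Blank mask: "Inpaint only masked" was disabled above; the warning is surfaced to the user through model_hijack comments (shown in the generation info).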
+ model_hijack.comments.append(massage) + else: + image_mask = images.resize_image(self.resize_mode, image_mask, self.width, self.height) + np_mask = np.array(image_mask) + np_mask = np.clip((np_mask.astype(np.float32)) * 2, 0, 255).astype(np.uint8) + self.mask_for_overlay = Image.fromarray(np_mask) + + self.overlay_images = [] + + latent_mask = self.latent_mask if self.latent_mask is not None else image_mask + + add_color_corrections = shared.opts.img2img_color_correction and self.color_corrections is None + if add_color_corrections: + self.color_corrections = [] + imgs = [] + for img in self.init_images: + + # Save init image + if shared.opts.save_init_img: + self.init_img_hash = hashlib.md5(img.tobytes()).hexdigest() + images.save_image(img, path=shared.opts.outdir_init_images, basename=None, forced_filename=self.init_img_hash, save_to_dirs=False, existing_info=img.info) + + image = images.flatten(img, shared.opts.img2img_background_color) + + if crop_region is None and self.resize_mode != 3: + image = images.resize_image(self.resize_mode, image, self.width, self.height) + + if image_mask is not None: + if self.mask_for_overlay.size != (image.width, image.height): + self.mask_for_overlay = images.resize_image(self.resize_mode, self.mask_for_overlay, image.width, image.height) + image_masked = Image.new('RGBa', (image.width, image.height)) + image_masked.paste(image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(self.mask_for_overlay.convert('L'))) + + self.overlay_images.append(image_masked.convert('RGBA')) + + # crop_region is not None if we are doing inpaint full res + if crop_region is not None: + image = image.crop(crop_region) + image = images.resize_image(2, image, self.width, self.height) + + if image_mask is not None: + if self.inpainting_fill != 1: + image = masking.fill(image, latent_mask) + + if self.inpainting_fill == 0: + self.extra_generation_params["Masked content"] = 'fill' + + if add_color_corrections: + self.color_corrections.append(processing.setup_color_correction(image)) + + image = np.array(image).astype(np.float32) / 255.0 + image = np.moveaxis(image, 2, 0) + + imgs.append(image) + + if len(imgs) == 1: + batch_images = np.expand_dims(imgs[0], axis=0).repeat(self.batch_size, axis=0) + if self.overlay_images is not None: + self.overlay_images = self.overlay_images * self.batch_size + + if self.color_corrections is not None and len(self.color_corrections) == 1: + self.color_corrections = self.color_corrections * self.batch_size + + elif len(imgs) <= self.batch_size: + self.batch_size = len(imgs) + batch_images = np.array(imgs) + else: + raise RuntimeError(f"bad number of images passed: {len(imgs)}; expecting {self.batch_size} or less") + + image = torch.from_numpy(batch_images) + self.image = image.to(shared.device, dtype=devices.dtype_vae) + +def process_images_inner_hunyuan(p: processing.StableDiffusionProcessing) -> processing.Processed: + """this is the main loop that both txt2img and img2img use; it calls func_init once inside all the scopes and func_sample once per batch""" + + if isinstance(p.prompt, list): + assert(len(p.prompt) > 0) + else: + assert p.prompt is not None + + devices.torch_gc() + + seed = processing.get_fixed_seed(p.seed) + subseed = processing.get_fixed_seed(p.subseed) + + if p.restore_faces is None: + p.restore_faces = shared.opts.face_restoration + + if p.tiling is None: + p.tiling = shared.opts.tiling + + # disable refiner + ''' + if p.refiner_checkpoint not in (None, "", "None", "none"): + p.refiner_checkpoint_info = 
sd_models.get_closet_checkpoint_match(p.refiner_checkpoint) + if p.refiner_checkpoint_info is None: + raise Exception(f'Could not find checkpoint with name {p.refiner_checkpoint}') + ''' + p.sd_model_name = shared.sd_model.sd_checkpoint_info.name_for_extra + p.sd_model_hash = shared.sd_model.sd_model_hash + # disable stable diffusion vae + ''' + p.sd_vae_name = sd_vae.get_loaded_vae_name() + p.sd_vae_hash = sd_vae.get_loaded_vae_hash() + ''' + model_hijack.apply_circular(p.tiling) + model_hijack.clear_comments() + + p.setup_prompts() + + if isinstance(seed, list): + p.all_seeds = seed + else: + p.all_seeds = [int(seed) + (x if p.subseed_strength == 0 else 0) for x in range(len(p.all_prompts))] + + if isinstance(subseed, list): + p.all_subseeds = subseed + else: + p.all_subseeds = [int(subseed) + x for x in range(len(p.all_prompts))] + + if os.path.exists(shared.cmd_opts.embeddings_dir) and not p.do_not_reload_embeddings: + model_hijack.embedding_db.load_textual_inversion_embeddings() + + if p.scripts is not None: + p.scripts.process(p) + + infotexts = [] + output_images = [] + with torch.no_grad(): + with devices.autocast(): + p.init(p.all_prompts, p.all_seeds, p.all_subseeds) + + # disable stable diffusion vae + ''' + # for OSX, loading the model during sampling changes the generated picture, so it is loaded here + if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN": + sd_vae_approx.model() + + sd_unet.apply_unet() + ''' + if shared.state.job_count == -1: + shared.state.job_count = p.n_iter + + for n in range(p.n_iter): + p.iteration = n + + if shared.state.skipped: + shared.state.skipped = False + + if shared.state.interrupted or shared.state.stopping_generation: + break + + sd_models.reload_model_weights() # model can be changed for example by refiner + + p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] + p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] + p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size] + p.subseeds = p.all_subseeds[n * p.batch_size:(n + 1) * p.batch_size] + + # disable webui rng for stable diffusion + #p.rng = rng.ImageRNG((opt_C, p.height // opt_f, p.width // opt_f), p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, seed_resize_from_h=p.seed_resize_from_h, seed_resize_from_w=p.seed_resize_from_w) + + if p.scripts is not None: + p.scripts.before_process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) + + if len(p.prompts) == 0: + break + # disabled sd webui type loras + ''' + p.parse_extra_network_prompts() + + if not p.disable_extra_networks: + with devices.autocast(): + extra_networks.activate(p, p.extra_network_data) + ''' + if p.scripts is not None: + p.scripts.process_batch(p, batch_number=n, prompts=p.prompts, seeds=p.seeds, subseeds=p.subseeds) + + p.setup_conds() + + # p.extra_generation_params.update(model_hijack.extra_generation_params) + + # params.txt should be saved after scripts.process_batch, since the + # infotext could be modified by that callback + # Example: a wildcard processed by process_batch sets an extra model + # strength, which is saved as "Model Strength: 1.0" in the infotext + if n == 0 and not shared.cmd_opts.no_prompt_history: + with open(os.path.join(paths.data_path, "params.txt"), "w", encoding="utf8") as file: + processed = processing.Processed(p, []) + file.write(processed.infotext(p, 0)) + + for comment in model_hijack.comments: + p.comment(comment) + + if p.n_iter > 1: + shared.state.job = 
f"Batch {n+1} out of {p.n_iter}" + + sd_models.apply_alpha_schedule_override(p.sd_model, p) + + with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast(): + samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds) + + if p.scripts is not None: + ps = scripts.PostSampleArgs(samples_ddim) + p.scripts.post_sample(p, ps) + samples_ddim = ps.samples + + if getattr(samples_ddim, 'already_decoded', False): + x_samples_ddim = samples_ddim + else: + if shared.opts.sd_vae_decode_method != 'Full': + p.extra_generation_params['VAE Decoder'] = shared.opts.sd_vae_decode_method + x_samples_ddim = shared.vae_model.decode(samples_ddim / shared.vae_model.config.scaling_factor, return_dict=False)[0] + + x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0) + x_samples_ddim = x_samples_ddim.cpu().permute(0, 2, 3, 1).float().numpy() + + del samples_ddim + + devices.torch_gc() + + shared.state.nextjob() + + if p.scripts is not None: + p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n) + + p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size] + p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size] + + batch_params = scripts.PostprocessBatchListArgs(list(x_samples_ddim)) + p.scripts.postprocess_batch_list(p, batch_params, batch_number=n) + x_samples_ddim = batch_params.images + + def infotext(index=0, use_main_prompt=False): + return processing.create_infotext(p, p.prompts, p.seeds, p.subseeds, use_main_prompt=use_main_prompt, index=index, all_negative_prompts=p.negative_prompts) + + save_samples = p.save_samples() + + for i, x_sample in enumerate(x_samples_ddim): + p.batch_index = i + + x_sample = 255. * x_sample + x_sample = x_sample.astype(np.uint8) + + if p.restore_faces: + if save_samples and shared.opts.save_images_before_face_restoration: + images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=infotext(i), p=p, suffix="-before-face-restoration") + + devices.torch_gc() + + x_sample = face_restoration.restore_faces(x_sample) + devices.torch_gc() + + image = Image.fromarray(x_sample) + + if p.scripts is not None: + pp = scripts.PostprocessImageArgs(image) + p.scripts.postprocess_image(p, pp) + image = pp.image + + mask_for_overlay = getattr(p, "mask_for_overlay", None) + + if not shared.opts.overlay_inpaint: + overlay_image = None + elif getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images): + overlay_image = p.overlay_images[i] + else: + overlay_image = None + + if p.scripts is not None: + ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image) + p.scripts.postprocess_maskoverlay(p, ppmo) + mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image + + if p.color_corrections is not None and i < len(p.color_corrections): + if save_samples and shared.opts.save_images_before_color_correction: + image_without_cc, _ = processing.apply_overlay(image, p.paste_to, overlay_image) + images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction") + image = processing.apply_color_correction(p.color_corrections[i], image) + + # If the intention is to show the output from the model + # that is being composited over the original image, + # we need to keep the original image around + # and use it in the composite step. 
+ image, original_denoised_image = processing.apply_overlay(image, p.paste_to, overlay_image) + + if p.scripts is not None: + pp = scripts.PostprocessImageArgs(image) + p.scripts.postprocess_image_after_composite(p, pp) + image = pp.image + + if save_samples: + images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=infotext(i), p=p) + + text = infotext(i) + infotexts.append(text) + if shared.opts.enable_pnginfo: + image.info["parameters"] = text + output_images.append(image) + + if mask_for_overlay is not None: + if shared.opts.return_mask or shared.opts.save_mask: + image_mask = mask_for_overlay.convert('RGB') + if save_samples and shared.opts.save_mask: + images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=infotext(i), p=p, suffix="-mask") + if shared.opts.return_mask: + output_images.append(image_mask) + + if shared.opts.return_mask_composite or shared.opts.save_mask_composite: + image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA') + if save_samples and shared.opts.save_mask_composite: + images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], shared.opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite") + if shared.opts.return_mask_composite: + output_images.append(image_mask_composite) + + del x_samples_ddim + + devices.torch_gc() + + if not infotexts: + infotexts.append(processing.Processed(p, []).infotext(p, 0)) + + p.color_corrections = None + + index_of_first_image = 0 + unwanted_grid_because_of_img_count = len(output_images) < 2 and shared.opts.grid_only_if_multiple + if (shared.opts.return_grid or shared.opts.grid_save) and not p.do_not_save_grid and not unwanted_grid_because_of_img_count: + grid = images.image_grid(output_images, p.batch_size) + + if shared.opts.return_grid: + text = infotext(use_main_prompt=True) + infotexts.insert(0, text) + if shared.opts.enable_pnginfo: + grid.info["parameters"] = text + output_images.insert(0, grid) + index_of_first_image = 1 + if shared.opts.grid_save: + images.save_image(grid, p.outpath_grids, "grid", p.all_seeds[0], p.all_prompts[0], shared.opts.grid_format, info=infotext(use_main_prompt=True), short_filename=not shared.opts.grid_extended_filename, p=p, grid=True) + + # disable sd webui type loras + ''' + if not p.disable_extra_networks and p.extra_network_data: + extra_networks.deactivate(p, p.extra_network_data) + ''' + devices.torch_gc() + + res = processing.Processed( + p, + images_list=output_images, + seed=p.all_seeds[0], + info=infotexts[0], + subseed=p.all_subseeds[0], + index_of_first_image=index_of_first_image, + infotexts=infotexts, + ) + + if p.scripts is not None: + p.scripts.postprocess(p, res) + + return res + +def load_model_hunyuan(checkpoint_info=None, already_loaded_state_dict=None): + from modules import sd_hijack + from diffusers import HunyuanDiT2DModel + checkpoint_info = checkpoint_info or sd_models.select_checkpoint() + + timer = Timer() + + if sd_models.model_data.sd_model: + sd_models.model_data.sd_model.to("cpu") + sd_models.model_data.sd_model = None + devices.torch_gc() + + timer.record("unload existing model") + + if already_loaded_state_dict is not None: + state_dict = already_loaded_state_dict + else: + state_dict = sd_models.get_checkpoint_state_dict(checkpoint_info, timer) + + 
timer.record("load weights from state dict") + + sd_model = HunyuanDiT2DModel.from_config(hunyuan_transformer_config_v12) + print("loading hunyuan DiT") + checkpoint_config = guess_dit_model(state_dict) + sd_model.used_config = checkpoint_config + if checkpoint_config == "hunyuan-original": + state_dict = convert_hunyuan_to_diffusers(state_dict) + elif "hunyuan" not in checkpoint_config: + raise ValueError("Found no hunyuan DiT model") + sd_model.load_state_dict(state_dict, strict=False) + del state_dict + + print("loading text encoder and vae") + shared.clip_l_model = BertModel.from_pretrained(shared.opts.Hunyuan_model_path,subfolder="text_encoder",torch_dtype=devices.dtype).to(devices.device) + shared.mt5_model = T5EncoderModel.from_pretrained(shared.opts.Hunyuan_model_path,subfolder="text_encoder_2",torch_dtype=devices.dtype).to(devices.device) + shared.clip_l_model.tokenizer = BertTokenizer.from_pretrained(shared.opts.Hunyuan_model_path,subfolder="tokenizer") + shared.mt5_model.tokenizer = MT5Tokenizer.from_pretrained(shared.opts.Hunyuan_model_path,subfolder="tokenizer_2") + shared.clip_l_model = sd_hijack_clip_diffusers.FrozenBertEmbedderWithCustomWords(shared.clip_l_model,sd_hijack.model_hijack) + shared.mt5_model = sd_hijack_clip_diffusers.FrozenT5EmbedderWithCustomWords(shared.mt5_model,sd_hijack.model_hijack) + shared.clip_l_model.return_masks = True + shared.mt5_model.return_masks = True + shared.vae_model = AutoencoderKL.from_pretrained(shared.opts.Hunyuan_model_path,subfolder="vae",torch_dtype=devices.dtype).to(devices.device) + + sd_model.to(devices.dtype) + sd_model.to(devices.device) + sd_model.eval() + sd_model_hash = checkpoint_info.calculate_shorthash() + sd_model.sd_model_hash = sd_model_hash + sd_model.sd_model_checkpoint = checkpoint_info.filename + sd_model.sd_checkpoint_info = checkpoint_info + sd_model.lowvram = False + sd_model.is_sd1 = False + sd_model.is_sd2 = False + sd_model.is_sdxl = False + sd_model.is_ssd = False + sd_model.is_sd3 = False + sd_model.model = None + sd_model.first_stage_model = None + sd_model.cond_stage_key = None + sd_model.cond_stage_model = None + sd_model.get_learned_conditioning = diffusers_learned_conditioning.get_learned_conditioning_hunyuan + sd_models.model_data.set_sd_model(sd_model) + sd_models.model_data.was_loaded_at_least_once = True + + script_callbacks.model_loaded_callback(sd_model) + + timer.record("scripts callbacks") + + print(f"Model loaded in {timer.summary()}.") + + return sd_model + +def reload_model_weights_hunyuan(sd_model=None, info=None, forced_reload=False): + checkpoint_info = info or sd_models.select_checkpoint() + + timer = Timer() + + if not sd_model: + sd_model = sd_models.model_data.sd_model + + if sd_model is None: # previous model load failed + current_checkpoint_info = None + else: + current_checkpoint_info = sd_model.sd_checkpoint_info + if sd_model.sd_model_checkpoint == checkpoint_info.filename and not forced_reload: + return sd_model + + sd_model.to(devices.dtype) + sd_model.to(devices.device) + if not forced_reload and sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename: + return sd_model + + if sd_model is not None: + sd_models.send_model_to_cpu(sd_model) + + state_dict = sd_models.get_checkpoint_state_dict(checkpoint_info, timer) + + checkpoint_config = guess_dit_model(state_dict) + if checkpoint_config == "hunyuan-original": + state_dict = convert_hunyuan_to_diffusers(state_dict) + elif "hunyuan" not in checkpoint_config: + raise ValueError("Found no hunyuan 
DiT model") + timer.record("find config") + + if sd_model is None or checkpoint_config != sd_model.used_config: + load_model_hunyuan(checkpoint_info, already_loaded_state_dict=state_dict) + return sd_models.model_data.sd_model + try: + sd_model.load_state_dict(state_dict, strict=False) + del state_dict + sd_model_hash = checkpoint_info.calculate_shorthash() + sd_model.sd_model_hash = sd_model_hash + sd_model.sd_model_checkpoint = checkpoint_info.filename + sd_model.sd_checkpoint_info = checkpoint_info + sd_model.lowvram = False + sd_model.is_sd1 = False + sd_model.is_sd2 = False + sd_model.is_sdxl = False + sd_model.is_ssd = False + sd_model.is_sd3 = False + sd_model.model = None + sd_model.first_stage_model = None + sd_model.cond_stage_key = None + sd_model.cond_stage_model = None + except Exception: + print("Failed to load checkpoint, restoring previous") + state_dict = sd_models.get_checkpoint_state_dict(current_checkpoint_info, timer) + sd_model.load_state_dict(state_dict, strict=False) + del state_dict + sd_model_hash = checkpoint_info.calculate_shorthash() + sd_model.sd_model_hash = sd_model_hash + sd_model.sd_model_checkpoint = checkpoint_info.filename + sd_model.sd_checkpoint_info = checkpoint_info + sd_model.lowvram = False + sd_model.is_sd1 = False + sd_model.is_sd2 = False + sd_model.is_sdxl = False + sd_model.is_ssd = False + sd_model.is_sd3 = False + sd_model.model = None + sd_model.first_stage_model = None + sd_model.cond_stage_key = None + sd_model.cond_stage_model = None + raise + finally: + script_callbacks.model_loaded_callback(sd_model) + timer.record("script callbacks") + + print(f"Weights loaded in {timer.summary()}.") + + sd_models.model_data.set_sd_model(sd_model) + + return sd_model + +class Script(scripts.Script): + + def __init__(self): + super(Script, self).__init__() + def title(self): + return 'Hunyuan DiT' + + def show(self, is_img2img): + return scripts.AlwaysVisible + + def ui(self, is_img2img): + tab = 't2i' if not is_img2img else 'i2i' + is_t2i = 'true' if not is_img2img else 'false' + uid = lambda name: f'MD-{tab}-{name}' + + with gr.Accordion('Hunyuan DiT', open=False): + with gr.Row(variant='compact') as tab_enable: + enabled = gr.Checkbox(label='Enable Hunyuan DiT', value=False, elem_id=uid('enabled')) + enabled.change( + fn=on_enable_change, + inputs=[enabled], + outputs=None + ) + return [ + enabled + ] + +def on_enable_change(enabled: bool): + if enabled: + print("Enable Hunyuan DiT") + hijack() + else: + print("Disable Hunyuan DiT") + reset() + shared.clip_l_model = unload_model(shared.clip_l_model) + shared.mt5_model = unload_model(shared.mt5_model) + shared.vae_model = unload_model(shared.vae_model) + +def reset(): + ''' unhijack inner APIs ''' + if hasattr(processing,"process_images_inner_original"): + processing.process_images_inner = processing.process_images_inner_original + if hasattr(processing.StableDiffusionProcessingTxt2Img,"sample_original"): + processing.StableDiffusionProcessingTxt2Img.sample = processing.StableDiffusionProcessingTxt2Img.sample_original + if hasattr(processing.StableDiffusionProcessingImg2Img,"sample_original"): + processing.StableDiffusionProcessingImg2Img.sample = processing.StableDiffusionProcessingImg2Img.sample_original + if hasattr(sd_models,"load_model_original"): + sd_models.load_model = sd_models.load_model_original + if hasattr(sd_models,"reload_model_weights_original"): + sd_models.reload_model_weights = sd_models.reload_model_weights_original + if 
hasattr(processing.StableDiffusionProcessingImg2Img,"init_img2img_original"): + processing.StableDiffusionProcessingImg2Img.init = processing.StableDiffusionProcessingImg2Img.init_img2img_original + +def hijack(): + ''' hijack inner APIs ''' + if not hasattr(processing,"process_images_inner_original"): + processing.process_images_inner_original = processing.process_images_inner + if not hasattr(processing.StableDiffusionProcessingTxt2Img,"sample_original"): + processing.StableDiffusionProcessingTxt2Img.sample_original = processing.StableDiffusionProcessingTxt2Img.sample + if not hasattr(processing.StableDiffusionProcessingImg2Img,"sample_original"): + processing.StableDiffusionProcessingImg2Img.sample_original = processing.StableDiffusionProcessingImg2Img.sample + if not hasattr(sd_models,"load_model_original"): + sd_models.load_model_original = sd_models.load_model + if not hasattr(sd_models,"reload_model_weights_original"): + sd_models.reload_model_weights_original = sd_models.reload_model_weights + if not hasattr(processing.StableDiffusionProcessingImg2Img,"init_img2img_original"): + processing.StableDiffusionProcessingImg2Img.init_img2img_original = processing.StableDiffusionProcessingImg2Img.init + processing.process_images_inner = process_images_inner_hunyuan + processing.StableDiffusionProcessingTxt2Img.sample = sample_txt2img + processing.StableDiffusionProcessingImg2Img.sample = sample_img2img + sd_models.load_model = load_model_hunyuan + sd_models.reload_model_weights = reload_model_weights_hunyuan + processing.StableDiffusionProcessingImg2Img.init = init_img2img + +def on_ui_settings(): + + shared.opts.add_option("Hunyuan_model_path", shared.OptionInfo("./models/hunyuan", "Hunyuan Model Path",section=('hunyuanDiT', "HunyuanDiT"))) + +script_callbacks.on_ui_settings(on_ui_settings)
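
Note (not part of the patch): the loading path above converts original-format HunyuanDiT weights into the diffusers key layout before handing them to `HunyuanDiT2DModel`. Below is a minimal standalone sketch of that flow, reusing the helpers this extension imports from `hunyuan_utils.utils`; the checkpoint filename is hypothetical and the real extension loads weights through webui's checkpoint machinery instead.

```python
# Minimal sketch, assuming a local original-format HunyuanDiT checkpoint file and the
# helpers exported by hunyuan_utils.utils as imported in scripts/hunyuandit.py.
import torch
from diffusers import HunyuanDiT2DModel

from hunyuan_utils.utils import (
    hunyuan_transformer_config_v12,
    guess_dit_model,
    convert_hunyuan_to_diffusers,
)

state_dict = torch.load("hunyuan_dit_v1.2.pt", map_location="cpu")  # hypothetical path

# Detect the key layout and remap original keys (t_embedder, x_embedder, mlp_t5, ...)
# to the diffusers names, mirroring load_model_hunyuan / reload_model_weights_hunyuan.
if guess_dit_model(state_dict) == "hunyuan-original":
    state_dict = convert_hunyuan_to_diffusers(state_dict)

model = HunyuanDiT2DModel.from_config(hunyuan_transformer_config_v12)
incompatible = model.load_state_dict(state_dict, strict=False)
print("missing:", len(incompatible.missing_keys),
      "unexpected:", len(incompatible.unexpected_keys))
model.eval()
```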