diff --git a/Notice b/Notice index aa7b677..be83e11 100644 --- a/Notice +++ b/Notice @@ -275,6 +275,9 @@ Copyright (c) 2021 Vision and Computational Cognition Group 8. sd-vae-ft-ema Copyright (c) sd-vae-ft-ema original author and authors +9. ComfyUI-Diffusers +Copyright (c) 2023 Limitex + Terms of the MIT License: -------------------------------------------------------------------- diff --git a/README.md b/README.md index 8a8366c..6a4e757 100644 --- a/README.md +++ b/README.md @@ -25,6 +25,7 @@ This repo contains PyTorch model definitions, pre-trained weights and inference/ > [**DialogGen: Multi-modal Interactive Dialogue System for Multi-turn Text-to-Image Generation**](https://arxiv.org/abs/2403.08857)
## 🔥🔥🔥 News!!
+* Jun 06, 2024: :tada: Hunyuan-DiT is now available in ComfyUI. Please check [ComfyUI](#using-comfyui) for more details.
* Jun 06, 2024: 🚀 We introduce Distillation version for Hunyuan-DiT acceleration, which achieves **50%** acceleration on NVIDIA GPUs. Please check [Tencent-Hunyuan/Distillation](https://huggingface.co/Tencent-Hunyuan/Distillation) for more details.
* Jun 05, 2024: 🤗 Hunyuan-DiT is now available in 🤗 Diffusers! Please check the [example](#using--diffusers) below.
* Jun 04, 2024: :globe_with_meridians: Support Tencent Cloud links to download the pretrained models! Please check the [links](#-download-pretrained-models) below.
@@ -73,7 +74,7 @@ or multi-turn language interactions to create the picture.
- [X] Web Demo (Gradio)
- [x] Multi-turn T2I Demo (Gradio)
- [X] Cli Demo
-- [ ] ComfyUI
+- [X] ComfyUI
- [X] Diffusers
- [ ] WebUI
@@ -94,6 +95,7 @@ or multi-turn language interactions to create the picture.
- [Using Diffusers](#using--diffusers)
- [Using Command Line](#using-command-line)
- [More Configurations](#more-configurations)
+ - [Using ComfyUI](#using-comfyui)
- [🚀 Acceleration (for Linux)](#-acceleration-for-linux)
- [🔗 BibTeX](#-bibtex)
@@ -389,6 +391,46 @@ We list some more useful configurations for easy usage:
| `--load-key` | ema | Load the student model or EMA model (ema or module) |
| `--load-4bit` | Fasle | Load DialogGen model with 4bit quantization |
+### Using ComfyUI
+
+We provide several commands for a quick start:
+
+```shell
+# Download the ComfyUI code
+git clone https://github.com/comfyanonymous/ComfyUI.git
+
+# Install torch, torchvision, torchaudio
+pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117
+
+# Install the essential ComfyUI Python packages
+cd ComfyUI
+pip install -r requirements.txt
+
+# ComfyUI has been successfully installed!
+
+# Download the model weights as before, or link an existing model folder to ComfyUI.
+python -m pip install "huggingface_hub[cli]"
+mkdir models/hunyuan
+huggingface-cli download Tencent-Hunyuan/HunyuanDiT --local-dir ./models/hunyuan/ckpts
+
+# Move to the ComfyUI custom_nodes folder and copy the comfyui-hydit folder from the HunyuanDiT repo.
+cd custom_nodes
+cp -r ${HunyuanDiT}/comfyui-hydit ./
+cd comfyui-hydit
+
+# Install the essential Python packages.
+pip install -r requirements.txt
+
+# Our tool has been successfully installed!
+
+# Go to the ComfyUI main folder
+cd ../..
+# Run the ComfyUI launch command
+python main.py --listen --port 80
+
+# ComfyUI is now running!
+```
+More details can be found in the [ComfyUI README](comfyui-hydit/README.md).
## 🚀 Acceleration (for Linux)
diff --git a/comfyui-hydit/LICENSE b/comfyui-hydit/LICENSE
new file mode 100644
index 0000000..150218d
--- /dev/null
+++ b/comfyui-hydit/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Limitex
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/comfyui-hydit/README.md b/comfyui-hydit/README.md
new file mode 100644
index 0000000..e3634be
--- /dev/null
+++ b/comfyui-hydit/README.md
@@ -0,0 +1,95 @@
+# comfyui-hydit
+
+This repository contains custom nodes and workflows designed specifically for HunyuanDiT. The official tests conducted on DDPM, DDIM, and DPMMS have consistently yielded results that align with those obtained through the Diffusers library. However, we cannot guarantee that other native ComfyUI samplers produce results consistent with Diffusers inference. We invite you to explore our workflows and welcome any questions or suggestions.
+
+## Overview
+
+
+### Workflow text2image
+
+![Workflow](img/txt2img_v2.png)
+
+[workflow_diffusers](workflow/hunyuan_diffusers_api.json): workflow file for HunyuanDiT txt2image with the diffusers backend.
+[workflow_ksampler](workflow/hunyuan_ksampler_api.json): workflow file for HunyuanDiT txt2image with the ksampler backend.
+
+
+## Usage
+
+We provide several commands for a quick start:
+
+```shell
+# Download the ComfyUI code
+git clone https://github.com/comfyanonymous/ComfyUI.git
+
+# Install torch, torchvision, torchaudio
+pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu117
+
+# Install the essential ComfyUI Python packages
+cd ComfyUI
+pip install -r requirements.txt
+
+# ComfyUI has been successfully installed!
+
+# Download the model weights as before, or link an existing model folder to ComfyUI.
+python -m pip install "huggingface_hub[cli]"
+mkdir models/hunyuan
+huggingface-cli download Tencent-Hunyuan/HunyuanDiT --local-dir ./models/hunyuan/ckpts
+
+# Move to the ComfyUI custom_nodes folder and copy the comfyui-hydit folder from the HunyuanDiT repo.
+cd custom_nodes
+cp -r ${HunyuanDiT}/comfyui-hydit ./
+cd comfyui-hydit
+
+# Install the essential Python packages.
+pip install -r requirements.txt
+
+# Our tool has been successfully installed!
+
+# Go to the ComfyUI main folder
+cd ../..
+# Run the ComfyUI launch command
+python main.py --listen --port 80
+
+# ComfyUI is now running!
+```
+
+
+
+## Custom Node
+The custom nodes are documented below, building on some excellent prior work [[1]](#1)[[2]](#2).
+#### HunYuan Pipeline Loader
+- Loads the full stack of models needed for HunYuanDiT.
+- **pipeline_folder_name** is the path to the official HunyuanDiT weight folder, which includes clip_text_encoder, model, mt5, sdxl-vae-fp16-fix and tokenizer.
+- **model_name** is the list of weights in the ComfyUI checkpoints folder.
+- **vae_name** is the list of weights in the ComfyUI vae folder.
+- **backend**: "diffusers" uses diffusers as the backend, while "ksampler" uses the ComfyUI KSampler backend.
+- **PIPELINE** is the instance of StableDiffusionPipeline.
+- **MODEL** is the instance of the ComfyUI MODEL.
+- **CLIP** is the instance of the ComfyUI CLIP.
+- **VAE** is the instance of the ComfyUI VAE.
+
+#### HunYuan Scheduler Loader
+- Loads the scheduler algorithm for HunYuanDiT.
+- **Input** is the algorithm name, one of ddpm, ddim and dpmms.
+- **Output** is the instance of diffusers.schedulers.
+
+#### HunYuan Model Makeup
+- Assembles the models and scheduler module.
+- **Input** is the instance of StableDiffusionPipeline and diffusers.schedulers.
+- **Output** is the updated instance of StableDiffusionPipeline.
+
+#### HunYuan Clip Text Encode
+- Encodes the positive and negative prompts for HunYuanDiT.
+- **Input** is the string of positive and negative prompts.
+- **Output** is the encoded text conditioning for the model.
+
+#### HunYuan Sampler
+- Similar to the KSampler in ComfyUI.
+- **Input** is the instance of StableDiffusionPipeline and some hyper-parameters for sampling.
+- **Output** is the generated image.
+
+## Reference
+[1]
+https://github.com/Limitex/ComfyUI-Diffusers
+[2]
+https://github.com/Tencent/HunyuanDiT/pull/59
diff --git a/comfyui-hydit/__init__.py b/comfyui-hydit/__init__.py
new file mode 100644
index 0000000..ee645d6
--- /dev/null
+++ b/comfyui-hydit/__init__.py
@@ -0,0 +1,4 @@
+from .nodes import *
+#aa = DiffusersSampler()
+#print(aa)
+__all__ = ["NODE_CLASS_MAPPINGS", "NODE_DISPLAY_NAME_MAPPINGS"]
diff --git a/comfyui-hydit/clip.py b/comfyui-hydit/clip.py
new file mode 100644
index 0000000..4549f86
--- /dev/null
+++ b/comfyui-hydit/clip.py
@@ -0,0 +1,113 @@
+import comfy.supported_models_base
+import comfy.latent_formats
+import comfy.model_patcher
+import comfy.model_base
+import comfy.utils
+from .hydit.modules.text_encoder import MT5Embedder
+from transformers import BertModel, BertTokenizer
+import torch
+import os
+
+class CLIP:
+    def __init__(self, root):
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        text_encoder_path = os.path.join(root,"clip_text_encoder")
+        clip_text_encoder = BertModel.from_pretrained(str(text_encoder_path), False, revision=None).to(self.device)
+        tokenizer_path = os.path.join(root,"tokenizer")
+        self.tokenizer = HyBertTokenizer(tokenizer_path)
+        t5_text_encoder_path = os.path.join(root,'mt5')
+        embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256)
+        self.tokenizer_t5 = HyT5Tokenizer(embedder_t5.tokenizer, max_length=embedder_t5.max_length)
+        self.embedder_t5 = embedder_t5.model
+
+        self.cond_stage_model = clip_text_encoder
+
+    def tokenize(self, text):
+        tokens = self.tokenizer.tokenize(text)
+        t5_tokens = self.tokenizer_t5.tokenize(text)
+        tokens.update(t5_tokens)
+        return tokens
+
+    def tokenize_t5(self, text):
+        return self.tokenizer_t5.tokenize(text)
+
+    def encode_from_tokens(self, tokens, return_pooled=False):
+        attention_mask = tokens['attention_mask'].to(self.device)
+        with torch.no_grad():
+            prompt_embeds = self.cond_stage_model(
+                tokens['text_input_ids'].to(self.device),
+                attention_mask=attention_mask
+            )
+        prompt_embeds = prompt_embeds[0]
+        t5_attention_mask = tokens['t5_attention_mask'].to(self.device)
+        with torch.no_grad():
+            t5_prompt_cond = self.embedder_t5(
+                tokens['t5_text_input_ids'].to(self.device),
+                attention_mask=t5_attention_mask
+            )
+        t5_embeds = t5_prompt_cond[0]
+
+        addit_embeds = {
+            "t5_embeds": t5_embeds,
+            "attention_mask": attention_mask.float(),
+            "t5_attention_mask": t5_attention_mask.float()
+        }
+        prompt_embeds.addit_embeds = addit_embeds
+
+        if return_pooled:
+            return prompt_embeds, None
+        else:
+            return prompt_embeds
+
+class HyBertTokenizer:
+    def __init__(self, tokenizer_path=None, max_length=77, truncation=True, return_attention_mask=True, device='cpu'):
+        self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path))
+        self.max_length = self.tokenizer.model_max_length or max_length
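+        # Prefer the tokenizer's own model_max_length when it defines one; the max_length argument (77, matching the --text-len default in hydit/config.py) is only a fallback.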
self.truncation = truncation + self.return_attention_mask = return_attention_mask + self.device = device + + def tokenize(self, text:str): + text_inputs = self.tokenizer( + text, + padding="max_length", + max_length=self.max_length, + truncation=self.truncation, + return_attention_mask=self.return_attention_mask, + add_special_tokens = True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + tokens = { + 'text_input_ids': text_input_ids, + 'attention_mask': attention_mask + } + return tokens + +class HyT5Tokenizer: + def __init__(self, tokenizer, max_length=77, truncation=True, return_attention_mask=True, device='cpu'): + self.tokenizer = tokenizer + self.max_length = max_length + self.truncation = truncation + self.return_attention_mask = return_attention_mask + self.device = device + + def tokenize(self, text:str): + text_inputs = self.tokenizer( + text, + padding="max_length", + max_length=self.max_length, + truncation=self.truncation, + return_attention_mask=self.return_attention_mask, + add_special_tokens = True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + tokens = { + 't5_text_input_ids': text_input_ids, + 't5_attention_mask': attention_mask + } + return tokens + diff --git a/comfyui-hydit/constant.py b/comfyui-hydit/constant.py new file mode 100644 index 0000000..7f778c7 --- /dev/null +++ b/comfyui-hydit/constant.py @@ -0,0 +1,6 @@ +import os +from .hydit.constants import SAMPLER_FACTORY + +base_path = os.path.dirname(os.path.realpath(__file__)) +HUNYUAN_PATH = os.path.join(base_path, "..", "..", "models", "hunyuan") +SCHEDULERS_hunyuan = list(SAMPLER_FACTORY.keys()) \ No newline at end of file diff --git a/comfyui-hydit/dit.py b/comfyui-hydit/dit.py new file mode 100644 index 0000000..afb0440 --- /dev/null +++ b/comfyui-hydit/dit.py @@ -0,0 +1,95 @@ +import comfy.supported_models_base +import comfy.latent_formats +import comfy.model_patcher +import comfy.model_base +import comfy.utils +from comfy import model_management +from .supported_dit_models import HunYuan_DiT, HYDiT_Model, ModifiedHunYuanDiT +from .clip import CLIP +import os +import folder_paths +import torch + +sampling_settings = { + "beta_schedule" : "linear", + "linear_start" : 0.00085, + "linear_end" : 0.03, + "timesteps" : 1000, +} + +hydit_conf = { + "G/2": { # Seems to be the main one + "unet_config": { + "depth" : 40, + "num_heads" : 16, + "patch_size" : 2, + "hidden_size" : 1408, + "mlp_ratio" : 4.3637, + "input_size": (1024//8, 1024//8), + #"disable_unet_model_creation": True, + }, + "sampling_settings" : sampling_settings, + }, +} + +def load_dit(model_path, output_clip=True, output_model=True, output_vae=True, MODEL_PATH = None, VAE_PATH = None): + if MODEL_PATH: + state_dict = comfy.utils.load_torch_file(MODEL_PATH) + else: + state_dict = comfy.utils.load_torch_file(os.path.join(model_path, "t2i", "model", "pytorch_model_ema.pt")) + + + state_dict = state_dict.get("model", state_dict) + parameters = comfy.utils.calculate_parameters(state_dict) + unet_dtype = model_management.unet_dtype(model_params=parameters) + load_device = comfy.model_management.get_torch_device() + offload_device = comfy.model_management.unet_offload_device() + clip = None, + vae = None + model_patcher = None + + # ignore fp8/etc and use directly for now + manual_cast_dtype = model_management.unet_manual_cast(unet_dtype, load_device) + root = os.path.join(model_path, "t2i") + if manual_cast_dtype: + 
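        # comfy's model_management reports that this device cannot compute in the preferred unet dtype, so fall back to the compute dtype it suggests.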
print(f"DiT: falling back to {manual_cast_dtype}")
+        unet_dtype = manual_cast_dtype
+
+    #model_conf["unet_config"]["num_classes"] = state_dict["y_embedder.embedding_table.weight"].shape[0] - 1 # adj. for empty
+
+    if output_model:
+        model_conf = HunYuan_DiT(hydit_conf["G/2"])
+        model = HYDiT_Model(
+            model_conf,
+            model_type=comfy.model_base.ModelType.V_PREDICTION,
+            device=model_management.get_torch_device()
+        )
+
+        #print(model_conf.unet_config)
+        #assert(0)
+
+        model.diffusion_model = ModifiedHunYuanDiT(model_conf.dit_conf, **model_conf.unet_config).half().to(load_device)
+
+        model.diffusion_model.load_state_dict(state_dict)
+        model.diffusion_model.eval()
+        model.diffusion_model.to(unet_dtype)
+
+        model_patcher = comfy.model_patcher.ModelPatcher(
+            model,
+            load_device = load_device,
+            offload_device = offload_device,
+            current_device = "cpu",
+        )
+    if output_clip:
+        clip = CLIP(root)
+
+    if output_vae:
+        if VAE_PATH:
+            vae_path = VAE_PATH
+        else:
+            vae_path = os.path.join(root, "sdxl-vae-fp16-fix", "diffusion_pytorch_model.safetensors")
+        #print(vae_path)
+        sd = comfy.utils.load_torch_file(vae_path)
+        vae = comfy.sd.VAE(sd=sd)
+
+    return (model_patcher, clip, vae)
diff --git a/comfyui-hydit/hydit/__init__.py b/comfyui-hydit/hydit/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/comfyui-hydit/hydit/config.py b/comfyui-hydit/hydit/config.py
new file mode 100644
index 0000000..47b7761
--- /dev/null
+++ b/comfyui-hydit/hydit/config.py
@@ -0,0 +1,69 @@
+import argparse
+
+from .constants import *
+from .modules.models import HUNYUAN_DIT_CONFIG
+
+
+def get_args(default_args=None):
+    parser_hunyuan = argparse.ArgumentParser()
+    #print(parser_hunyuan)
+
+    # Basic
+    parser_hunyuan.add_argument("--prompt", type=str, default="现实主义风格,画面主要描述一个巴洛克风格的花瓶,带有金色的装饰边框,花瓶上盛开着各种色彩鲜艳的花,白色背景", help="The prompt for generating images.")
+    parser_hunyuan.add_argument("--model-root", type=str, default="ckpts", help="Model root path.")
+    parser_hunyuan.add_argument("--image-size", type=int, nargs='+', default=[1024, 1024],
+                                help='Image size (h, w). If a single value is provided, the image will be treated to '
+                                     '(value, value).')
+    parser_hunyuan.add_argument("--infer-mode", type=str, choices=["fa", "torch", "trt"], default="torch",
+                                help="Inference mode")
+
+    # HunYuan-DiT
+    parser_hunyuan.add_argument("--model", type=str, choices=list(HUNYUAN_DIT_CONFIG.keys()), default='DiT-g/2')
+    parser_hunyuan.add_argument("--norm", type=str, default="layer", help="Normalization layer type")
+    parser_hunyuan.add_argument("--load-key", type=str, choices=["ema", "module", "distill"], default="ema", help="Load model key for HunYuanDiT checkpoint.")
+    parser_hunyuan.add_argument('--size-cond', type=int, nargs='+', default=[1024, 1024],
+                                help="Size condition used in sampling. 2 values are required for height and width. 
" + "If a single value is provided, the image will be treated to (value, value).") + parser_hunyuan.add_argument("--cfg-scale", type=float, default=6.0, help="Guidance scale for classifier-free.") + + # Prompt enhancement + parser_hunyuan.add_argument("--enhance", action="store_true", help="Enhance prompt with dialoggen.") + parser_hunyuan.add_argument("--no-enhance", dest="enhance", action="store_false") + parser_hunyuan.add_argument("--load-4bit", help="load DialogGen model with 4bit quantization.", action="store_true") + parser_hunyuan.set_defaults(enhance=True) + + # Diffusion + parser_hunyuan.add_argument("--learn-sigma", action="store_true", help="Learn extra channels for sigma.") + parser_hunyuan.add_argument("--no-learn-sigma", dest="learn_sigma", action="store_false") + parser_hunyuan.set_defaults(learn_sigma=True) + parser_hunyuan.add_argument("--predict-type", type=str, choices=list(PREDICT_TYPE), default="v_prediction", + help="Diffusion predict type") + parser_hunyuan.add_argument("--noise-schedule", type=str, choices=list(NOISE_SCHEDULES), default="scaled_linear", + help="Noise schedule") + parser_hunyuan.add_argument("--beta-start", type=float, default=0.00085, help="Beta start value") + parser_hunyuan.add_argument("--beta-end", type=float, default=0.03, help="Beta end value") + + # Text condition + parser_hunyuan.add_argument("--text-states-dim", type=int, default=1024, help="Hidden size of CLIP text encoder.") + parser_hunyuan.add_argument("--text-len", type=int, default=77, help="Token length of CLIP text encoder output.") + parser_hunyuan.add_argument("--text-states-dim-t5", type=int, default=2048, help="Hidden size of CLIP text encoder.") + parser_hunyuan.add_argument("--text-len-t5", type=int, default=256, help="Token length of T5 text encoder output.") + parser_hunyuan.add_argument("--negative", type=str, default="错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,", help="Negative prompt.") + + # Acceleration + parser_hunyuan.add_argument("--use_fp16", action="store_true", help="Use FP16 precision.") + parser_hunyuan.add_argument("--no-fp16", dest="use_fp16", action="store_false") + parser_hunyuan.set_defaults(use_fp16=True) + + # Sampling + parser_hunyuan.add_argument("--batch-size", type=int, default=1, help="Per-GPU batch size") + parser_hunyuan.add_argument("--sampler", type=str, choices=SAMPLER_FACTORY, default="ddim", help="Diffusion sampler") + parser_hunyuan.add_argument("--infer-steps", type=int, default=30, help="Inference steps") + parser_hunyuan.add_argument('--seed', type=int, default=666, help="A seed for all the prompts.") + + # App + parser_hunyuan.add_argument("--lang", type=str, default="zh", choices=["zh", "en"], help="Language") + + args = parser_hunyuan.parse_known_args() + + return args diff --git a/comfyui-hydit/hydit/constants.py b/comfyui-hydit/hydit/constants.py new file mode 100644 index 0000000..f2cdf81 --- /dev/null +++ b/comfyui-hydit/hydit/constants.py @@ -0,0 +1,62 @@ +# ======================================================= +NOISE_SCHEDULES = { + "linear", + "scaled_linear", + "squaredcos_cap_v2", +} + +PREDICT_TYPE = { + "epsilon", + "sample", + "v_prediction", +} + +# ======================================================= +NEGATIVE_PROMPT = '错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,' + + +# ======================================================= +# Constants about models +# ======================================================= + +SAMPLER_FACTORY = { + 'ddpm': { + 'scheduler': 'DDPMScheduler', + 'name': 'DDPM', + 'kwargs': { 
+ 'steps_offset': 1, + 'clip_sample': False, + 'clip_sample_range': 1.0, + 'beta_schedule': 'scaled_linear', + 'beta_start': 0.00085, + 'beta_end': 0.03, + 'prediction_type': 'v_prediction', + } + }, + 'ddim': { + 'scheduler': 'DDIMScheduler', + 'name': 'DDIM', + 'kwargs': { + 'steps_offset': 1, + 'clip_sample': False, + 'clip_sample_range': 1.0, + 'beta_schedule': 'scaled_linear', + 'beta_start': 0.00085, + 'beta_end': 0.03, + 'prediction_type': 'v_prediction', + } + }, + 'dpmms': { + 'scheduler': 'DPMSolverMultistepScheduler', + 'name': 'DPMMS', + 'kwargs': { + 'beta_schedule': 'scaled_linear', + 'beta_start': 0.00085, + 'beta_end': 0.03, + 'prediction_type': 'v_prediction', + 'trained_betas': None, + 'solver_order': 2, + 'algorithm_type': 'dpmsolver++', + } + }, +} diff --git a/comfyui-hydit/hydit/diffusion/__init__.py b/comfyui-hydit/hydit/diffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/comfyui-hydit/hydit/diffusion/pipeline.py b/comfyui-hydit/hydit/diffusion/pipeline.py new file mode 100644 index 0000000..cfa07af --- /dev/null +++ b/comfyui-hydit/hydit/diffusion/pipeline.py @@ -0,0 +1,865 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import PIL +import numpy as np +import torch +import torchvision.transforms as T +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + PIL_INTERPOLATION, + deprecate, + logging, + replace_example_docstring, +) +from diffusers.utils.torch_utils import randn_tensor +from transformers import BertModel, BertTokenizer +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from ..modules.models import HunYuanDiT +import pdb + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionImg2ImgPipeline + + >>> device = "cuda" + >>> model_id_or_path = "runwayml/stable-diffusion-v1-5" + >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + >>> response = requests.get(url) + >>> init_image = 
Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> prompt = "A fantasy landscape, trending on artstation" + + >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images + >>> images[0].save("fantasy_landscape.png") + ``` +""" +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + +def preprocess(image): + deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" + deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 8 for x in (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionPipeline( + DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin +): + r""" + Pipeline for text-guided image-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder (Optional[`~transformers.BertModel`, `~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer (Optional[`~transformers.BertTokenizer`, `~transformers.CLIPTokenizer`]): + A `BertTokenizer` or `CLIPTokenizer` to tokenize text. + unet (Optional[`HunYuanDiT`, `UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. 
Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor"] + _exclude_from_cpu_offload = ["safety_checker"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: Union[BertModel, CLIPTextModel], + tokenizer: Union[BertTokenizer, CLIPTokenizer], + unet: Union[HunYuanDiT, UNet2DConditionModel], + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + requires_safety_checker: bool = True, + progress_bar_config: Dict[str, Any] = None, + embedder_t5=None, + infer_mode='torch', + ): + super().__init__() + + # ======================================================== + self.embedder_t5 = embedder_t5 + self.infer_mode = infer_mode + + # ======================================================== + if progress_bar_config is None: + progress_bar_config = {} + if not hasattr(self, '_progress_bar_config'): + self._progress_bar_config = {} + self._progress_bar_config.update(progress_bar_config) + # ======================================================== + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. 
Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." + deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + embedder=None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). 
+ prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + embedder: + T5 embedder (including text encoder and tokenizer) + """ + if embedder is None: + text_encoder = self.text_encoder + tokenizer = self.tokenizer + max_length = self.tokenizer.model_max_length + else: + text_encoder = embedder.model + tokenizer = embedder.tokenizer + max_length = embedder.max_length + + #pdb.set_trace() + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode( + untruncated_ids[:, tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + attention_mask = text_inputs.attention_mask.to(device) + prompt_embeds = text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + attention_mask = attention_mask.repeat(num_images_per_prompt, 1) + else: + attention_mask = None + + if text_encoder is not None: + prompt_embeds_dtype = text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + 
f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_attention_mask = uncond_input.attention_mask.to(device) + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + attention_mask=uncond_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + uncond_attention_mask = uncond_attention_mask.repeat(num_images_per_prompt, 1) + else: + uncond_attention_mask = None + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask + + def _convert_to_rgb(self, image): + return image.convert('RGB') + + def image_transform(self, image_size=224): + transform = T.Compose([ + T.Resize((image_size, image_size), interpolation=T.InterpolationMode.BICUBIC), + self._convert_to_rgb, + T.ToTensor(), + T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ]) + return transform + + def encode_img(self, img, device, do_classifier_free_guidance): + # print('len', len(img)) + # print('img', img.size) + img = img[0] # TODO: support batch processing + image_preprocess = self.image_transform(224) + img_for_clip = image_preprocess(img) + # print('img_for_clip', img_for_clip.shape) + img_for_clip = img_for_clip.unsqueeze(0) + img_clip_embedding = self.img_encoder(img_for_clip.to(device)).to(dtype=torch.float16) + # print('img_clip_embedding_1_type', img_clip_embedding.dtype) + if do_classifier_free_guidance: + negative_img_clip_embedding = torch.zeros_like(img_clip_embedding) + return img_clip_embedding, negative_img_clip_embedding + + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept 
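+    # NOTE: the ComfyUI integration builds this pipeline with safety_checker=None (see get_pipeline in hydit/inference.py), so run_safety_checker above returns the images unchanged with has_nsfw_concept=None.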
+ + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+ ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + height: int, + width: int, + prompt: Union[str, List[str]] = None, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_embeds_t5: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds_t5: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + image_meta_size: Optional[torch.LongTensor] = None, + style: Optional[torch.LongTensor] = None, + progress: bool = True, + use_fp16: bool = False, + freqs_cis_img: Optional[tuple] = None, + learn_sigma: bool = True, + ): + r""" + The call function to the pipeline for generation. + + Args: + height (`int`): + The height in pixels of the generated image. + width (`int`): + The width in pixels of the generated image. + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`): + `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both + numpy array and pytorch tensor, the expected value range is between `[0, 1]` If it's a tensor or a list + or tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a + list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)` It can also accept image + latents as `image`, but if passing latents directly it is not encoded again. + strength (`float`, *optional*, defaults to 1.0): + Indicates extent to transform the reference `image`. Must be between 0 and 1. 
`image` is used as a + starting point and more noise is added the higher the `strength`. The number of denoising steps depends + on the amount of noise initially added. When `strength` is 1, added noise is maximum and the denoising + process runs for the full number of iterations specified in `num_inference_steps`. A value of 1 + essentially ignores `image`. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter is modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor, + pred_x0: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
+ + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. + """ + # 1. Check inputs. Raise error if not correct + #print(self.scheduler) + #assert(0) + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + + #print([prompt,device,num_images_per_prompt, do_classifier_free_guidance, prompt_embeds,negative_prompt_embeds, text_encoder_lora_scale]) + #assert(0) + prompt_embeds, negative_prompt_embeds, attention_mask, uncond_attention_mask = \ + self.encode_prompt(prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + prompt_embeds_t5, negative_prompt_embeds_t5, attention_mask_t5, uncond_attention_mask_t5 = \ + self.encode_prompt(prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds_t5, + negative_prompt_embeds=negative_prompt_embeds_t5, + lora_scale=text_encoder_lora_scale, + embedder=self.embedder_t5, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + attention_mask = torch.cat([uncond_attention_mask, attention_mask]) + prompt_embeds_t5 = torch.cat([negative_prompt_embeds_t5, prompt_embeds_t5]) + attention_mask_t5 = torch.cat([uncond_attention_mask_t5, attention_mask_t5]) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + #print(prompt_embeds.dtype) + #print(device) + #print(generator) + #assert(0) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents(batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + + # Set the save path + #save_path = "/apdcephfs_cq8/share_1367250/xuhuaren/comfyui_project/comfyui_debug.pt" + + # Save the variables as a dictionary + """ + torch.save({ + "prompt_embeds": prompt_embeds, + "prompt_embeds_t5": prompt_embeds_t5, + "scheduler": self.scheduler, + "latents": latents + }, save_path) + assert(0) + """ + + # 6. 
Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 8. Denoising loop + #print(len(timesteps)) + #print(num_inference_steps) + #assert(0) + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + #with open("/apdcephfs_cq8/share_1367250/xuhuaren/dit-open/HunyuanDiT/output_python.txt", "a") as output_file: + # output_file.write(f"{t}\n") + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + # expand scalar t to 1-D tensor to match the 1st dim of latent_model_input + t_expand = torch.tensor([t] * latent_model_input.shape[0], device=latent_model_input.device) + + if use_fp16: + latent_model_input = latent_model_input.half() + t_expand = t_expand.half() + prompt_embeds = prompt_embeds.half() + ims = image_meta_size.half() if image_meta_size is not None else None + else: + ims = image_meta_size if image_meta_size is not None else None + + + #print(ims) + #assert(0) + + # predict the noise residual + if self.infer_mode in ["fa", "torch"]: + noise_pred = self.unet( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=attention_mask, + encoder_hidden_states_t5=prompt_embeds_t5, + text_embedding_mask_t5=attention_mask_t5, + image_meta_size=ims, + style=style, + cos_cis_img=freqs_cis_img[0], + sin_cis_img=freqs_cis_img[1], + return_dict=False, + ) + elif self.infer_mode == "trt": + raise NotImplementedError("TensorRT model is not supported yet.") + else: + raise ValueError("[ERROR] invalid inference mode! please check your config file") + if learn_sigma: + noise_pred, _ = noise_pred.chunk(2, dim=1) + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + if do_classifier_free_guidance and guidance_rescale > 0.0: + # Based on 3.4. 
in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + results = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=True) + latents = results.prev_sample + pred_x0 = results.pred_original_sample if hasattr(results, 'pred_original_sample') else None + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents, pred_x0) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) diff --git a/comfyui-hydit/hydit/inference.py b/comfyui-hydit/hydit/inference.py new file mode 100644 index 0000000..1c20807 --- /dev/null +++ b/comfyui-hydit/hydit/inference.py @@ -0,0 +1,406 @@ +import random +import time +from pathlib import Path + +import numpy as np +import torch + +# For reproducibility +# torch.backends.cudnn.benchmark = False +# torch.backends.cudnn.deterministic = True + +from diffusers import schedulers +from diffusers.models import AutoencoderKL +from loguru import logger +from transformers import BertModel, BertTokenizer +from transformers.modeling_utils import logger as tf_logger + +from .constants import SAMPLER_FACTORY, NEGATIVE_PROMPT +from .diffusion.pipeline import StableDiffusionPipeline +from .modules.models import HunYuanDiT, HUNYUAN_DIT_CONFIG +from .modules.posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop +from .modules.text_encoder import MT5Embedder +from .utils.tools import set_seeds + + +class Resolution: + def __init__(self, width, height): + self.width = width + self.height = height + + def __str__(self): + return f'{self.height}x{self.width}' + + +class ResolutionGroup: + def __init__(self): + self.data = [ + Resolution(768, 768), # 1:1 + Resolution(1024, 1024), # 1:1 + Resolution(1280, 1280), # 1:1 + Resolution(1024, 768), # 4:3 + Resolution(1152, 864), # 4:3 + Resolution(1280, 960), # 4:3 + Resolution(768, 1024), # 3:4 + Resolution(864, 1152), # 3:4 + Resolution(960, 1280), # 3:4 + Resolution(1280, 768), # 16:9 + Resolution(768, 1280), # 9:16 + ] + self.supported_sizes = set([(r.width, r.height) for r in self.data]) + + def is_valid(self, width, height): + return (width, height) in self.supported_sizes + + +STANDARD_RATIO = np.array([ + 1.0, # 1:1 + 4.0 / 3.0, # 4:3 + 3.0 / 4.0, # 3:4 + 16.0 / 9.0, # 16:9 + 9.0 / 16.0, # 9:16 +]) +STANDARD_SHAPE = [ + [(768, 768), (1024, 1024), (1280, 1280)], # 1:1 + [(1024, 768), (1152, 864), (1280, 960)], # 4:3 + [(768, 1024), (864, 1152), (960, 1280)], # 3:4 + [(1280, 768)], # 16:9 + [(768, 1280)], # 9:16 +] +STANDARD_AREA = [ + np.array([w * h for w, h in shapes]) + for shapes in STANDARD_SHAPE +] + + +def 
get_standard_shape(target_width, target_height): + """ + Map image size to standard size. + """ + target_ratio = target_width / target_height + closest_ratio_idx = np.argmin(np.abs(STANDARD_RATIO - target_ratio)) + closest_area_idx = np.argmin(np.abs(STANDARD_AREA[closest_ratio_idx] - target_width * target_height)) + width, height = STANDARD_SHAPE[closest_ratio_idx][closest_area_idx] + return width, height + + +def _to_tuple(val): + if isinstance(val, (list, tuple)): + if len(val) == 1: + val = [val[0], val[0]] + elif len(val) == 2: + val = tuple(val) + else: + raise ValueError(f"Invalid value: {val}") + elif isinstance(val, (int, float)): + val = (val, val) + else: + raise ValueError(f"Invalid value: {val}") + return val + + +def get_pipeline(args, vae, text_encoder, tokenizer, model, device, rank, + embedder_t5, infer_mode, sampler=None): + """ + Get scheduler and pipeline for sampling. The sampler and pipeline are both + based on diffusers and make some modifications. + + Returns + ------- + pipeline: StableDiffusionPipeline + sampler_name: str + """ + sampler = sampler or args.sampler + + # Load sampler from factory + kwargs = SAMPLER_FACTORY[sampler]['kwargs'] + scheduler = SAMPLER_FACTORY[sampler]['scheduler'] + + # Update sampler according to the arguments + kwargs['beta_schedule'] = args.noise_schedule + kwargs['beta_start'] = args.beta_start + kwargs['beta_end'] = args.beta_end + kwargs['prediction_type'] = args.predict_type + + # Build scheduler according to the sampler. + scheduler_class = getattr(schedulers, scheduler) + scheduler = scheduler_class(**kwargs) + #print(scheduler) + #assert(0) + + # Set timesteps for inference steps. + scheduler.set_timesteps(args.infer_steps, device) + + # Only enable progress bar for rank 0 + progress_bar_config = {} if rank == 0 else {'disable': True} + + pipeline = StableDiffusionPipeline(vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=model, + scheduler=scheduler, + feature_extractor=None, + safety_checker=None, + requires_safety_checker=False, + progress_bar_config=progress_bar_config, + embedder_t5=embedder_t5, + infer_mode=infer_mode, + ) + + pipeline = pipeline.to(device) + + return pipeline, sampler + + +class End2End(object): + def __init__(self, args, models_root_path, MODEL_PATH = None, VAE_PATH = None): + self.args = args + + # Check arguments + t2i_root_path = Path(models_root_path) / "t2i" + self.root = t2i_root_path + logger.info(f"Got text-to-image model root path: {t2i_root_path}") + + # Set device and disable gradient + self.device = "cuda" if torch.cuda.is_available() else "cpu" + torch.set_grad_enabled(False) + # Disable BertModel logging checkpoint info + tf_logger.setLevel('ERROR') + + # ======================================================================== + model_dir = self.root / "model" + + # ======================================================================== + logger.info(f"Loading CLIP Text Encoder...") + text_encoder_path = self.root / "clip_text_encoder" + self.clip_text_encoder = BertModel.from_pretrained(str(text_encoder_path), False, revision=None).to(self.device) + logger.info(f"Loading CLIP Text Encoder finished") + #print(self.clip_text_encoder) + + # ======================================================================== + logger.info(f"Loading CLIP Tokenizer...") + tokenizer_path = self.root / "tokenizer" + self.tokenizer = BertTokenizer.from_pretrained(str(tokenizer_path)) + logger.info(f"Loading CLIP Tokenizer finished") + #print(self.tokenizer) + #assert(0) + + # 
======================================================================== + logger.info(f"Loading T5 Text Encoder and T5 Tokenizer...") + t5_text_encoder_path = self.root / 'mt5' + embedder_t5 = MT5Embedder(t5_text_encoder_path, torch_dtype=torch.float16, max_length=256) + self.embedder_t5 = embedder_t5 + logger.info(f"Loading t5_text_encoder and t5_tokenizer finished") + + # ======================================================================== + logger.info(f"Loading VAE...") + if VAE_PATH: + vae_path = VAE_PATH + self.vae = AutoencoderKL.from_single_file(str(vae_path)).to(self.device) + else: + vae_path = self.root / "sdxl-vae-fp16-fix" + self.vae = AutoencoderKL.from_pretrained(str(vae_path)).to(self.device) + + + logger.info(f"Loading VAE finished") + + # ======================================================================== + # Create model structure and load the checkpoint + logger.info(f"Building HunYuan-DiT model...") + #print(self.args.model) + #print(self.args) + model_config = HUNYUAN_DIT_CONFIG[self.args.model] + self.patch_size = model_config['patch_size'] + self.head_size = model_config['hidden_size'] // model_config['num_heads'] + self.resolutions, self.freqs_cis_img = self.standard_shapes() # Used for TensorRT models + self.image_size = _to_tuple(self.args.image_size) + latent_size = (self.image_size[0] // 8, self.image_size[1] // 8) + + self.infer_mode = self.args.infer_mode + if self.infer_mode in ['fa', 'torch']: + if MODEL_PATH: + model_path = Path(MODEL_PATH) + else: + model_path = model_dir / f"pytorch_model_{self.args.load_key}.pt" + + if not model_path.exists(): + raise ValueError(f"model_path not exists: {model_path}") + # Build model structure + self.model = HunYuanDiT(self.args, + input_size=latent_size, + **model_config, + log_fn=logger.info, + ).half().to(self.device) # Force to use fp16 + # Load model checkpoint + logger.info(f"Loading model checkpoint {model_path}...") + state_dict = torch.load(model_path, map_location=lambda storage, loc: storage) + self.model.load_state_dict(state_dict) + self.model.eval() + elif self.infer_mode == 'trt': + raise NotImplementedError("TensorRT model is not supported yet.") + else: + raise ValueError(f"Unknown infer_mode: {self.infer_mode}") + + # ======================================================================== + # Build inference pipeline. We use a customized StableDiffusionPipeline. + logger.info(f"Loading inference pipeline...") + #self.pipeline, self.sampler = self.load_sampler() + logger.info(f'Loading pipeline finished') + + # ======================================================================== + self.default_negative_prompt = NEGATIVE_PROMPT + logger.info("==================================================") + logger.info(f" Model is ready. 
") + logger.info("==================================================") + + def load_sampler(self, sampler=None): + pipeline, sampler = get_pipeline(self.args, + self.vae, + self.clip_text_encoder, + self.tokenizer, + self.model, + device=self.device, + rank=0, + embedder_t5=self.embedder_t5, + infer_mode=self.infer_mode, + sampler=sampler, + ) + return pipeline, sampler + + def calc_rope(self, height, width): + th = height // 8 // self.patch_size + tw = width // 8 // self.patch_size + base_size = 512 // 8 // self.patch_size + start, stop = get_fill_resize_and_crop((th, tw), base_size) + sub_args = [start, stop, (th, tw)] + rope = get_2d_rotary_pos_embed(self.head_size, *sub_args) + return rope + + def standard_shapes(self): + resolutions = ResolutionGroup() + freqs_cis_img = {} + for reso in resolutions.data: + freqs_cis_img[str(reso)] = self.calc_rope(reso.height, reso.width) + return resolutions, freqs_cis_img + + def predict(self, + user_prompt, + height=1024, + width=1024, + seed=None, + enhanced_prompt=None, + negative_prompt=None, + infer_steps=100, + guidance_scale=6, + batch_size=1, + src_size_cond=(1024, 1024), + sampler=None, + ): + # ======================================================================== + # Arguments: seed + # ======================================================================== + if seed is None: + seed = random.randint(0, 1_000_000) + if not isinstance(seed, int): + raise TypeError(f"`seed` must be an integer, but got {type(seed)}") + generator = set_seeds(seed) + + # ======================================================================== + # Arguments: target_width, target_height + # ======================================================================== + if width <= 0 or height <= 0: + raise ValueError(f"`height` and `width` must be positive integers, got height={height}, width={width}") + logger.info(f"Input (height, width) = ({height}, {width})") + if self.infer_mode in ['fa', 'torch']: + # We must force height and width to align to 16 and to be an integer. + target_height = int((height // 16) * 16) + target_width = int((width // 16) * 16) + logger.info(f"Align to 16: (height, width) = ({target_height}, {target_width})") + elif self.infer_mode == 'trt': + target_width, target_height = get_standard_shape(width, height) + logger.info(f"Align to standard shape: (height, width) = ({target_height}, {target_width})") + else: + raise ValueError(f"Unknown infer_mode: {self.infer_mode}") + + # ======================================================================== + # Arguments: prompt, new_prompt, negative_prompt + # ======================================================================== + if not isinstance(user_prompt, str): + raise TypeError(f"`user_prompt` must be a string, but got {type(user_prompt)}") + user_prompt = user_prompt.strip() + prompt = user_prompt + + if enhanced_prompt is not None: + if not isinstance(enhanced_prompt, str): + raise TypeError(f"`enhanced_prompt` must be a string, but got {type(enhanced_prompt)}") + enhanced_prompt = enhanced_prompt.strip() + prompt = enhanced_prompt + + # negative prompt + if negative_prompt is None or negative_prompt == '': + negative_prompt = self.default_negative_prompt + if not isinstance(negative_prompt, str): + raise TypeError(f"`negative_prompt` must be a string, but got {type(negative_prompt)}") + + # ======================================================================== + # Arguments: style. (A fixed argument. Don't Change it.) 
+ # ======================================================================== + style = torch.as_tensor([0, 0] * batch_size, device=self.device) + + # ======================================================================== + # Inner arguments: image_meta_size (Please refer to SDXL.) + # ======================================================================== + if isinstance(src_size_cond, int): + src_size_cond = [src_size_cond, src_size_cond] + if not isinstance(src_size_cond, (list, tuple)): + raise TypeError(f"`src_size_cond` must be a list or tuple, but got {type(src_size_cond)}") + if len(src_size_cond) != 2: + raise ValueError(f"`src_size_cond` must be a tuple of 2 integers, but got {len(src_size_cond)}") + size_cond = list(src_size_cond) + [target_width, target_height, 0, 0] + image_meta_size = torch.as_tensor([size_cond] * 2 * batch_size, device=self.device) + + # ======================================================================== + start_time = time.time() + logger.debug(f""" + prompt: {user_prompt} + enhanced prompt: {enhanced_prompt} + seed: {seed} + (height, width): {(target_height, target_width)} + negative_prompt: {negative_prompt} + batch_size: {batch_size} + guidance_scale: {guidance_scale} + infer_steps: {infer_steps} + image_meta_size: {size_cond} + """) + reso = f'{target_height}x{target_width}' + if reso in self.freqs_cis_img: + freqs_cis_img = self.freqs_cis_img[reso] + else: + freqs_cis_img = self.calc_rope(target_height, target_width) + + #if sampler is not None and sampler != self.sampler: + # self.pipeline, self.sampler = self.load_sampler(sampler) + + samples = self.pipeline( + height=target_height, + width=target_width, + prompt=prompt, + negative_prompt=negative_prompt, + num_images_per_prompt=batch_size, + guidance_scale=guidance_scale, + num_inference_steps=infer_steps, + image_meta_size=image_meta_size, + style=style, + return_dict=False, + generator=generator, + freqs_cis_img=freqs_cis_img, + use_fp16=self.args.use_fp16, + learn_sigma=self.args.learn_sigma, + )[0] + gen_time = time.time() - start_time + logger.debug(f"Success, time: {gen_time}") + + return { + 'images': samples, + 'seed': seed, + } diff --git a/comfyui-hydit/hydit/modules/__init__.py b/comfyui-hydit/hydit/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/comfyui-hydit/hydit/modules/attn_layers.py b/comfyui-hydit/hydit/modules/attn_layers.py new file mode 100644 index 0000000..4308af9 --- /dev/null +++ b/comfyui-hydit/hydit/modules/attn_layers.py @@ -0,0 +1,377 @@ +import torch +import torch.nn as nn +from typing import Tuple, Union, Optional + +try: + import flash_attn + if hasattr(flash_attn, '__version__') and int(flash_attn.__version__[0]) == 2: + from flash_attn.flash_attn_interface import flash_attn_kvpacked_func + from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention + else: + from flash_attn.flash_attn_interface import flash_attn_unpadded_kvpacked_func + from flash_attn.modules.mha import FlashSelfAttention, FlashCrossAttention +except Exception as e: + print(f'flash_attn import failed: {e}') + + +def reshape_for_broadcast(freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], x: torch.Tensor, head_first=False): + """ + Reshape frequency tensor for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. 
+ + Args: + freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + torch.Tensor: Reshaped frequency tensor. + + Raises: + AssertionError: If the frequency tensor doesn't match the expected shape. + AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions. + """ + ndim = x.ndim + assert 0 <= 1 < ndim + + if isinstance(freqs_cis, tuple): + # freqs_cis: (cos, sin) in real space + if head_first: + assert freqs_cis[0].shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}' + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis[0].shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}' + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape) + else: + # freqs_cis: values in complex space + if head_first: + assert freqs_cis.shape == (x.shape[-2], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}' + shape = [d if i == ndim - 2 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + else: + assert freqs_cis.shape == (x.shape[1], x.shape[-1]), f'freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}' + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def rotate_half(x): + x_real, x_imag = x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + return torch.stack([-x_imag, x_real], dim=-1).flatten(3) + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: Optional[torch.Tensor], + freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], + head_first: bool = False, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Apply rotary embeddings to input tensors using the given frequency tensor. + + This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided + frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor + is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are + returned as real tensors. + + Args: + xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D] + xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D] + freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Precomputed frequency tensor for complex exponentials. + head_first (bool): head dimension first (except batch dim) or not. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings. 
+ + """ + xk_out = None + if isinstance(freqs_cis, tuple): + cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D] + cos, sin = cos.to(xq.device), sin.to(xq.device) + xq_out = (xq.float() * cos + rotate_half(xq.float()) * sin).type_as(xq) + if xk is not None: + xk_out = (xk.float() * cos + rotate_half(xk.float()) * sin).type_as(xk) + else: + xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) # [B, S, H, D//2] + freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(xq.device) # [S, D//2] --> [1, S, 1, D//2] + xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq) + if xk is not None: + xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) # [B, S, H, D//2] + xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk) + + return xq_out, xk_out + + +class FlashSelfMHAModified(nn.Module): + """ + Use QK Normalization. + """ + def __init__(self, + dim, + num_heads, + qkv_bias=True, + qk_norm=False, + attn_drop=0.0, + proj_drop=0.0, + device=None, + dtype=None, + norm_layer=nn.LayerNorm, + ): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.dim = dim + self.num_heads = num_heads + assert self.dim % num_heads == 0, "self.kdim must be divisible by num_heads" + self.head_dim = self.dim // num_heads + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + + self.Wqkv = nn.Linear(dim, 3 * dim, bias=qkv_bias, **factory_kwargs) + # TODO: eps should be 1 / 65530 if using fp16 + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.inner_attn = FlashSelfAttention(attention_dropout=attn_drop) + self.out_proj = nn.Linear(dim, dim, bias=qkv_bias, **factory_kwargs) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, freqs_cis_img=None): + """ + Parameters + ---------- + x: torch.Tensor + (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) + freqs_cis_img: torch.Tensor + (batch, hidden_dim // 2), RoPE for image + """ + b, s, d = x.shape + + qkv = self.Wqkv(x) + qkv = qkv.view(b, s, 3, self.num_heads, self.head_dim) # [b, s, 3, h, d] + q, k, v = qkv.unbind(dim=2) # [b, s, h, d] + q = self.q_norm(q).half() # [b, s, h, d] + k = self.k_norm(k).half() + + # Apply RoPE if needed + if freqs_cis_img is not None: + qq, kk = apply_rotary_emb(q, k, freqs_cis_img) + assert qq.shape == q.shape and kk.shape == k.shape, f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}' + q, k = qq, kk + + qkv = torch.stack([q, k, v], dim=2) # [b, s, 3, h, d] + context = self.inner_attn(qkv) + out = self.out_proj(context.view(b, s, d)) + out = self.proj_drop(out) + + out_tuple = (out,) + + return out_tuple + + +class FlashCrossMHAModified(nn.Module): + """ + Use QK Normalization. 
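To make the tensor shapes above concrete, here is a small usage sketch of `apply_rotary_emb` with a `(cos, sin)` table. The table construction is only a stand-in for the 2D RoPE produced by `get_2d_rotary_pos_embed` elsewhere in this diff, and the import path assumes the `comfyui-hydit/hydit` package is importable:

```python
# Illustrative sketch, not part of the diff: rotate query/key heads with a
# hand-built (cos, sin) table of shape (seq_len, head_dim).
import torch
from hydit.modules.attn_layers import apply_rotary_emb  # path as added in this diff

B, S, H, D = 2, 16, 8, 64                     # batch, seq len, heads, head dim
q = torch.randn(B, S, H, D)
k = torch.randn(B, S, H, D)

pos = torch.arange(S).float()
inv_freq = 1.0 / (10000 ** (torch.arange(0, D, 2).float() / D))
angles = torch.outer(pos, inv_freq)                       # (S, D/2)
cos = torch.cos(angles).repeat_interleave(2, dim=-1)      # (S, D)
sin = torch.sin(angles).repeat_interleave(2, dim=-1)      # (S, D)

q_rot, k_rot = apply_rotary_emb(q, k, (cos, sin), head_first=False)
print(q_rot.shape, k_rot.shape)               # shapes are preserved: (B, S, H, D)
```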
+ """ + def __init__(self, + qdim, + kdim, + num_heads, + qkv_bias=True, + qk_norm=False, + attn_drop=0.0, + proj_drop=0.0, + device=None, + dtype=None, + norm_layer=nn.LayerNorm, + ): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.qdim = qdim + self.kdim = kdim + self.num_heads = num_heads + assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads" + self.head_dim = self.qdim // num_heads + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + + self.scale = self.head_dim ** -0.5 + + self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs) + + # TODO: eps should be 1 / 65530 if using fp16 + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + + self.inner_attn = FlashCrossAttention(attention_dropout=attn_drop) + self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, y, freqs_cis_img=None): + """ + Parameters + ---------- + x: torch.Tensor + (batch, seqlen1, hidden_dim) (where hidden_dim = num_heads * head_dim) + y: torch.Tensor + (batch, seqlen2, hidden_dim2) + freqs_cis_img: torch.Tensor + (batch, hidden_dim // num_heads), RoPE for image + """ + b, s1, _ = x.shape # [b, s1, D] + _, s2, _ = y.shape # [b, s2, 1024] + + q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d] + kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d] + k, v = kv.unbind(dim=2) # [b, s2, h, d] + q = self.q_norm(q).half() # [b, s1, h, d] + k = self.k_norm(k).half() # [b, s2, h, d] + + # Apply RoPE if needed + if freqs_cis_img is not None: + qq, _ = apply_rotary_emb(q, None, freqs_cis_img) + assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}' + q = qq # [b, s1, h, d] + kv = torch.stack([k, v], dim=2) # [b, s1, 2, h, d] + context = self.inner_attn(q, kv) # [b, s1, h, d] + context = context.view(b, s1, -1) # [b, s1, D] + + out = self.out_proj(context) + out = self.proj_drop(out) + + out_tuple = (out,) + + return out_tuple + + +class CrossAttention(nn.Module): + """ + Use QK Normalization. 
+ """ + def __init__(self, + qdim, + kdim, + num_heads, + qkv_bias=True, + qk_norm=False, + attn_drop=0.0, + proj_drop=0.0, + device=None, + dtype=None, + norm_layer=nn.LayerNorm, + ): + factory_kwargs = {'device': device, 'dtype': dtype} + super().__init__() + self.qdim = qdim + self.kdim = kdim + self.num_heads = num_heads + assert self.qdim % num_heads == 0, "self.qdim must be divisible by num_heads" + self.head_dim = self.qdim // num_heads + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + self.scale = self.head_dim ** -0.5 + + self.q_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.kv_proj = nn.Linear(kdim, 2 * qdim, bias=qkv_bias, **factory_kwargs) + + # TODO: eps should be 1 / 65530 if using fp16 + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.out_proj = nn.Linear(qdim, qdim, bias=qkv_bias, **factory_kwargs) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, y, freqs_cis_img=None): + """ + Parameters + ---------- + x: torch.Tensor + (batch, seqlen1, hidden_dim) (where hidden_dim = num heads * head dim) + y: torch.Tensor + (batch, seqlen2, hidden_dim2) + freqs_cis_img: torch.Tensor + (batch, hidden_dim // 2), RoPE for image + """ + b, s1, c = x.shape # [b, s1, D] + _, s2, c = y.shape # [b, s2, 1024] + + q = self.q_proj(x).view(b, s1, self.num_heads, self.head_dim) # [b, s1, h, d] + kv = self.kv_proj(y).view(b, s2, 2, self.num_heads, self.head_dim) # [b, s2, 2, h, d] + k, v = kv.unbind(dim=2) # [b, s, h, d] + q = self.q_norm(q) + k = self.k_norm(k) + + # Apply RoPE if needed + if freqs_cis_img is not None: + qq, _ = apply_rotary_emb(q, None, freqs_cis_img) + assert qq.shape == q.shape, f'qq: {qq.shape}, q: {q.shape}' + q = qq + + q = q * self.scale + q = q.transpose(-2, -3).contiguous() # q -> B, L1, H, C - B, H, L1, C + k = k.permute(0, 2, 3, 1).contiguous() # k -> B, L2, H, C - B, H, C, L2 + attn = q @ k # attn -> B, H, L1, L2 + attn = attn.softmax(dim=-1) # attn -> B, H, L1, L2 + attn = self.attn_drop(attn) + x = attn @ v.transpose(-2, -3) # v -> B, L2, H, C - B, H, L2, C x-> B, H, L1, C + context = x.transpose(1, 2) # context -> B, H, L1, C - B, L1, H, C + + context = context.contiguous().view(b, s1, -1) + + out = self.out_proj(context) # context.reshape - B, L1, -1 + out = self.proj_drop(out) + + out_tuple = (out,) + + return out_tuple + + +class Attention(nn.Module): + """ + We rename some layer names to align with flash attention + """ + def __init__(self, dim, num_heads, qkv_bias=True, qk_norm=False, attn_drop=0., proj_drop=0., + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + assert self.dim % num_heads == 0, 'dim should be divisible by num_heads' + self.head_dim = self.dim // num_heads + # This assertion is aligned with flash attention + assert self.head_dim % 8 == 0 and self.head_dim <= 128, "Only support head_dim <= 128 and divisible by 8" + self.scale = self.head_dim ** -0.5 + + # qkv --> Wqkv + self.Wqkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + # TODO: eps should be 1 / 65530 if using fp16 + self.q_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim, elementwise_affine=True, eps=1e-6) if qk_norm else nn.Identity() + self.attn_drop 
= nn.Dropout(attn_drop) + self.out_proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, freqs_cis_img=None): + B, N, C = x.shape + qkv = self.Wqkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) # [3, b, h, s, d] + q, k, v = qkv.unbind(0) # [b, h, s, d] + q = self.q_norm(q) # [b, h, s, d] + k = self.k_norm(k) # [b, h, s, d] + + # Apply RoPE if needed + if freqs_cis_img is not None: + qq, kk = apply_rotary_emb(q, k, freqs_cis_img, head_first=True) + assert qq.shape == q.shape and kk.shape == k.shape, \ + f'qq: {qq.shape}, q: {q.shape}, kk: {kk.shape}, k: {k.shape}' + q, k = qq, kk + + q = q * self.scale + attn = q @ k.transpose(-2, -1) # [b, h, s, d] @ [b, h, d, s] + attn = attn.softmax(dim=-1) # [b, h, s, s] + attn = self.attn_drop(attn) + x = attn @ v # [b, h, s, d] + + x = x.transpose(1, 2).reshape(B, N, C) # [b, s, h, d] + x = self.out_proj(x) + x = self.proj_drop(x) + + out_tuple = (x,) + + return out_tuple diff --git a/comfyui-hydit/hydit/modules/embedders.py b/comfyui-hydit/hydit/modules/embedders.py new file mode 100644 index 0000000..9fe08cb --- /dev/null +++ b/comfyui-hydit/hydit/modules/embedders.py @@ -0,0 +1,111 @@ +import math +import torch +import torch.nn as nn +from einops import repeat + +from timm.models.layers import to_2tuple + + +class PatchEmbed(nn.Module): + """ 2D Image to Patch Embedding + + Image to Patch Embedding using Conv2d + + A convolution based approach to patchifying a 2D image w/ embedding projection. + + Based on the impl in https://github.com/google-research/vision_transformer + + Hacked together by / Copyright 2020 Ross Wightman + + Remove the _assert function in forward function to be compatible with multi-resolution images. + """ + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + norm_layer=None, + flatten=True, + bias=True, + ): + super().__init__() + if isinstance(img_size, int): + img_size = to_2tuple(img_size) + elif isinstance(img_size, (tuple, list)) and len(img_size) == 2: + img_size = tuple(img_size) + else: + raise ValueError(f"img_size must be int or tuple/list of length 2. Got {img_size}") + patch_size = to_2tuple(patch_size) + self.img_size = img_size + self.patch_size = patch_size + self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + self.flatten = flatten + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def update_image_size(self, img_size): + self.img_size = img_size + self.grid_size = (img_size[0] // self.patch_size[0], img_size[1] // self.patch_size[1]) + self.num_patches = self.grid_size[0] * self.grid_size[1] + + def forward(self, x): + # B, C, H, W = x.shape + # _assert(H == self.img_size[0], f"Input image height ({H}) doesn't match model ({self.img_size[0]}).") + # _assert(W == self.img_size[1], f"Input image width ({W}) doesn't match model ({self.img_size[1]}).") + x = self.proj(x) + if self.flatten: + x = x.flatten(2).transpose(1, 2) # BCHW -> BNC + x = self.norm(x) + return x + + +def timestep_embedding(t, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. 
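As a shape-only sketch of the patch embedding above (the import path assumes this diff's package layout is on the Python path): a 2x2 convolutional patchify maps a 4-channel latent to a token sequence.

```python
# Illustrative only: the latent for a 1024x768 image is (1, 4, 128, 96) after the
# VAE's 8x downscale; PatchEmbed with patch_size=2 yields (H/16)*(W/16) tokens.
import torch
from hydit.modules.embedders import PatchEmbed  # path as added in this diff

embed = PatchEmbed(img_size=(128, 96), patch_size=2, in_chans=4, embed_dim=1152)
z = torch.randn(1, 4, 128, 96)
tokens = embed(z)
print(tokens.shape)   # torch.Size([1, 3072, 1152]); 3072 = 64 * 48 patches
```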
+ :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=t.device) # size: [dim/2], 一个指数衰减的曲线 + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(t, "b -> b d", d=dim) + return embedding + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + def __init__(self, hidden_size, frequency_embedding_size=256, out_size=None): + super().__init__() + if out_size is None: + out_size = hidden_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, out_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size).type(self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb diff --git a/comfyui-hydit/hydit/modules/models.py b/comfyui-hydit/hydit/modules/models.py new file mode 100644 index 0000000..d125aa9 --- /dev/null +++ b/comfyui-hydit/hydit/modules/models.py @@ -0,0 +1,409 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models import ModelMixin +from timm.models.vision_transformer import Mlp + +from .attn_layers import Attention, FlashCrossMHAModified, FlashSelfMHAModified, CrossAttention +from .embedders import TimestepEmbedder, PatchEmbed, timestep_embedding +from .norm_layers import RMSNorm +from .poolers import AttentionPool + + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class FP32_Layernorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), + self.eps).to(origin_dtype) + + +class FP32_SiLU(nn.SiLU): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype) + + +class HunYuanDiTBlock(nn.Module): + """ + A HunYuanDiT block with `add` conditioning. 
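A quick sanity check of the sinusoidal timestep embedding defined above (again assuming the module path from this diff):

```python
# Each scalar timestep becomes a dim-d vector: cosines in the first half,
# sines in the second half (see timestep_embedding above).
import torch
from hydit.modules.embedders import timestep_embedding  # path as added in this diff

t = torch.tensor([0, 10, 999])
emb = timestep_embedding(t, dim=256)
print(emb.shape)                                       # torch.Size([3, 256])
print(torch.allclose(emb[0, :128], torch.ones(128)))   # cos(0) = 1 -> True
print(float(emb[0, 128:].abs().max()))                 # sin(0) = 0 -> 0.0
```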
+ """ + def __init__(self, + hidden_size, + c_emb_size, + num_heads, + mlp_ratio=4.0, + text_states_dim=1024, + use_flash_attn=False, + qk_norm=False, + norm_type="layer", + skip=False, + ): + super().__init__() + self.use_flash_attn = use_flash_attn + use_ele_affine = True + + if norm_type == "layer": + norm_layer = FP32_Layernorm + elif norm_type == "rms": + norm_layer = RMSNorm + else: + raise ValueError(f"Unknown norm_type: {norm_type}") + + # ========================= Self-Attention ========================= + self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6) + if use_flash_attn: + self.attn1 = FlashSelfMHAModified(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm) + else: + self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm) + + # ========================= FFN ========================= + self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + + # ========================= Add ========================= + # Simply use add like SDXL. + self.default_modulation = nn.Sequential( + FP32_SiLU(), + nn.Linear(c_emb_size, hidden_size, bias=True) + ) + + # ========================= Cross-Attention ========================= + if use_flash_attn: + self.attn2 = FlashCrossMHAModified(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True, + qk_norm=qk_norm) + else: + self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True, + qk_norm=qk_norm) + self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6) + + # ========================= Skip Connection ========================= + if skip: + self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6) + self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) + else: + self.skip_linear = None + + def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None): + # Long Skip Connection + if self.skip_linear is not None: + cat = torch.cat([x, skip], dim=-1) + cat = self.skip_norm(cat) + x = self.skip_linear(cat) + + # Self-Attention + shift_msa = self.default_modulation(c).unsqueeze(dim=1) + attn_inputs = ( + self.norm1(x) + shift_msa, freq_cis_img, + ) + x = x + self.attn1(*attn_inputs)[0] + + # Cross-Attention + cross_inputs = ( + self.norm3(x), text_states, freq_cis_img + ) + x = x + self.attn2(*cross_inputs)[0] + + # FFN Layer + mlp_inputs = self.norm2(x) + x = x + self.mlp(mlp_inputs) + + return x + + +class FinalLayer(nn.Module): + """ + The final layer of HunYuanDiT. + """ + def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + FP32_SiLU(), + nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class HunYuanDiT(ModelMixin, ConfigMixin): + """ + HunYuanDiT: Diffusion model with a Transformer backbone. 
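The `modulate` helper used by `FinalLayer` above is the usual adaLN trick: one conditioning vector produces per-channel shift and scale that are broadcast over all tokens. A minimal check (assuming `modulate` from this file is importable):

```python
# With zero shift and zero scale the modulation is the identity; this is the
# state the model starts in, since initialize_weights() zero-inits the adaLN
# projection of FinalLayer.
import torch
from hydit.modules.models import modulate  # path as added in this diff

B, L, D = 2, 4, 8
x = torch.randn(B, L, D)
shift = torch.zeros(B, D)
scale = torch.zeros(B, D)
out = modulate(x, shift, scale)   # x * (1 + scale) + shift, broadcast over L
print(torch.equal(out, x))        # True
```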
+ + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + + Parameters + ---------- + args: argparse.Namespace + The arguments parsed by argparse. + input_size: tuple + The size of the input image. + patch_size: int + The size of the patch. + in_channels: int + The number of input channels. + hidden_size: int + The hidden size of the transformer backbone. + depth: int + The number of transformer blocks. + num_heads: int + The number of attention heads. + mlp_ratio: float + The ratio of the hidden size of the MLP in the transformer block. + log_fn: callable + The logging function. + """ + @register_to_config + def __init__( + self, args, + input_size=(32, 32), + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + log_fn=print, + ): + super().__init__() + self.args = args + self.log_fn = log_fn + self.depth = depth + self.learn_sigma = args.learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if args.learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.hidden_size = hidden_size + self.text_states_dim = args.text_states_dim + self.text_states_dim_t5 = args.text_states_dim_t5 + self.text_len = args.text_len + self.text_len_t5 = args.text_len_t5 + self.norm = args.norm + + use_flash_attn = args.infer_mode == 'fa' + if use_flash_attn: + log_fn(f" Enable Flash Attention.") + qk_norm = True # See http://arxiv.org/abs/2302.05442 for details. + + self.mlp_t5 = nn.Sequential( + nn.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True), + FP32_SiLU(), + nn.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True), + ) + # learnable replace + self.text_embedding_padding = nn.Parameter( + torch.randn(self.text_len + self.text_len_t5, self.text_states_dim, dtype=torch.float32)) + + # Attention pooling + self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=1024) + + # Here we use a default learned embedder layer for future extension. + self.style_embedder = nn.Embedding(1, hidden_size) + + # Image size and crop size conditions + self.extra_in_dim = 256 * 6 + hidden_size + + # Text embedding for `add` + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + self.extra_in_dim += 1024 + self.extra_embedder = nn.Sequential( + nn.Linear(self.extra_in_dim, hidden_size * 4), + FP32_SiLU(), + nn.Linear(hidden_size * 4, hidden_size, bias=True), + ) + + # Image embedding + num_patches = self.x_embedder.num_patches + log_fn(f" Number of tokens: {num_patches}") + + # HUnYuanDiT Blocks + self.blocks = nn.ModuleList([ + HunYuanDiTBlock(hidden_size=hidden_size, + c_emb_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + text_states_dim=self.text_states_dim, + use_flash_attn=use_flash_attn, + qk_norm=qk_norm, + norm_type=self.norm, + skip=layer > depth // 2, + ) + for layer in range(depth) + ]) + + self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels) + self.unpatchify_channels = self.out_channels + + self.initialize_weights() + + def forward(self, + x, + t, + encoder_hidden_states=None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + cos_cis_img=None, + sin_cis_img=None, + return_dict=True, + ): + """ + Forward pass of the encoder. 
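The block list above turns on `skip` for the second half of the blocks; the forward pass that follows pushes activations from the early blocks onto a stack and pops them as long-skip inputs for the late blocks. A dry run of just that index bookkeeping (pure Python, no model):

```python
# Mirrors the conditions used in HunYuanDiT.forward: push for layers
# 0 .. depth//2 - 2, pop for layers depth//2 + 1 .. depth - 1.
depth = 40           # e.g. DiT-g/2
skips = []
pairs = []
for layer in range(depth):
    if layer > depth // 2:
        pairs.append((layer, skips.pop()))   # late block consumes a saved skip
    if layer < (depth // 2 - 1):
        skips.append(layer)                  # early block's output is saved
print(len(pairs), skips)   # 19 pairs, and the stack ends empty
```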
+ + Parameters + ---------- + x: torch.Tensor + (B, D, H, W) + t: torch.Tensor + (B) + encoder_hidden_states: torch.Tensor + CLIP text embedding, (B, L_clip, D) + text_embedding_mask: torch.Tensor + CLIP text embedding mask, (B, L_clip) + encoder_hidden_states_t5: torch.Tensor + T5 text embedding, (B, L_t5, D) + text_embedding_mask_t5: torch.Tensor + T5 text embedding mask, (B, L_t5) + image_meta_size: torch.Tensor + (B, 6) + style: torch.Tensor + (B) + cos_cis_img: torch.Tensor + sin_cis_img: torch.Tensor + return_dict: bool + Whether to return a dictionary. + """ + + text_states = encoder_hidden_states # 2,77,1024 + text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 + text_states_mask = text_embedding_mask.bool() # 2,77 + text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256 + b_t5, l_t5, c_t5 = text_states_t5.shape + text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)) + text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1) # 2,205,1024 + clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1) + + clip_t5_mask = clip_t5_mask + text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states)) + + _, _, oh, ow = x.shape + th, tw = oh // self.patch_size, ow // self.patch_size + + # ========================= Build time and image embedding ========================= + t = self.t_embedder(t) + x = self.x_embedder(x) + + # Get image RoPE embedding according to `reso`lution. + freqs_cis_img = (cos_cis_img, sin_cis_img) + + # ========================= Concatenate all extra vectors ========================= + # Build text tokens with pooling + extra_vec = self.pooler(encoder_hidden_states_t5) + + # Build image meta size tokens + image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256] + if self.args.use_fp16: + image_meta_size = image_meta_size.half() + image_meta_size = image_meta_size.view(-1, 6 * 256) + extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256] + + # Build style tokens + style_embedding = self.style_embedder(style) + extra_vec = torch.cat([extra_vec, style_embedding], dim=1) + + # Concatenate all extra vectors + c = t + self.extra_embedder(extra_vec) # [B, D] + + # ========================= Forward pass through HunYuanDiT blocks ========================= + skips = [] + for layer, block in enumerate(self.blocks): + if layer > self.depth // 2: + skip = skips.pop() + x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D) + else: + x = block(x, c, text_states, freqs_cis_img) # (N, L, D) + + if layer < (self.depth // 2 - 1): + skips.append(x) + + # ========================= Final layer ========================= + x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels) + x = self.unpatchify(x, th, tw) # (N, out_channels, H, W) + + if return_dict: + return {'x': x} + return x + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.extra_embedder[0].weight, std=0.02) + nn.init.normal_(self.extra_embedder[2].weight, std=0.02) + + # 
Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in HunYuanDiT blocks: + for block in self.blocks: + nn.init.constant_(block.default_modulation[-1].weight, 0) + nn.init.constant_(block.default_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x, h, w): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.unpatchify_channels + p = self.x_embedder.patch_size[0] + # h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs + + +################################################################################# +# HunYuanDiT Configs # +################################################################################# + +HUNYUAN_DIT_CONFIG = { + 'DiT-g/2': {'depth': 40, 'hidden_size': 1408, 'patch_size': 2, 'num_heads': 16, 'mlp_ratio': 4.3637}, + 'DiT-XL/2': {'depth': 28, 'hidden_size': 1152, 'patch_size': 2, 'num_heads': 16}, + 'DiT-L/2': {'depth': 24, 'hidden_size': 1024, 'patch_size': 2, 'num_heads': 16}, + 'DiT-B/2': {'depth': 12, 'hidden_size': 768, 'patch_size': 2, 'num_heads': 12}, +} diff --git a/comfyui-hydit/hydit/modules/models_comfyui.py b/comfyui-hydit/hydit/modules/models_comfyui.py new file mode 100644 index 0000000..eb2c8a3 --- /dev/null +++ b/comfyui-hydit/hydit/modules/models_comfyui.py @@ -0,0 +1,423 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models import ModelMixin +from timm.models.vision_transformer import Mlp + +from .attn_layers import Attention, FlashCrossMHAModified, FlashSelfMHAModified, CrossAttention +from .embedders import TimestepEmbedder, PatchEmbed, timestep_embedding +from .norm_layers import RMSNorm +from .poolers import AttentionPool +from .posemb_layers import get_2d_rotary_pos_embed, get_fill_resize_and_crop + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +class FP32_Layernorm(nn.LayerNorm): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + origin_dtype = inputs.dtype + return F.layer_norm(inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), + self.eps).to(origin_dtype) + + +class FP32_SiLU(nn.SiLU): + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + return torch.nn.functional.silu(inputs.float(), inplace=False).to(inputs.dtype) + + +class HunYuanDiTBlock(nn.Module): + """ + A HunYuanDiT block with `add` conditioning. 
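A shape-only sketch of the `unpatchify` step defined above, which folds FinalLayer's per-token patch predictions back into an image-shaped latent (the sizes below are arbitrary):

```python
# (N, T, p*p*C) tokens -> (N, C, h*p, w*p) image, exactly the reshape/einsum
# used in HunYuanDiT.unpatchify.
import torch

N, C, p, h, w = 1, 8, 2, 4, 6                 # out_channels, patch size, patch grid
x = torch.randn(N, h * w, p * p * C)          # what FinalLayer produces
x = x.reshape(N, h, w, p, p, C)
x = torch.einsum('nhwpqc->nchpwq', x)
imgs = x.reshape(N, C, h * p, w * p)
print(imgs.shape)                             # torch.Size([1, 8, 8, 12])
```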
+ """ + def __init__(self, + hidden_size, + c_emb_size, + num_heads, + mlp_ratio=4.0, + text_states_dim=1024, + use_flash_attn=False, + qk_norm=False, + norm_type="layer", + skip=False, + ): + super().__init__() + self.use_flash_attn = use_flash_attn + use_ele_affine = True + + if norm_type == "layer": + norm_layer = FP32_Layernorm + elif norm_type == "rms": + norm_layer = RMSNorm + else: + raise ValueError(f"Unknown norm_type: {norm_type}") + + # ========================= Self-Attention ========================= + self.norm1 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6) + if use_flash_attn: + self.attn1 = FlashSelfMHAModified(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm) + else: + self.attn1 = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=qk_norm) + + # ========================= FFN ========================= + self.norm2 = norm_layer(hidden_size, elementwise_affine=use_ele_affine, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = lambda: nn.GELU(approximate="tanh") + self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0) + + # ========================= Add ========================= + # Simply use add like SDXL. + self.default_modulation = nn.Sequential( + FP32_SiLU(), + nn.Linear(c_emb_size, hidden_size, bias=True) + ) + + # ========================= Cross-Attention ========================= + if use_flash_attn: + self.attn2 = FlashCrossMHAModified(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True, + qk_norm=qk_norm) + else: + self.attn2 = CrossAttention(hidden_size, text_states_dim, num_heads=num_heads, qkv_bias=True, + qk_norm=qk_norm) + self.norm3 = norm_layer(hidden_size, elementwise_affine=True, eps=1e-6) + + # ========================= Skip Connection ========================= + if skip: + self.skip_norm = norm_layer(2 * hidden_size, elementwise_affine=True, eps=1e-6) + self.skip_linear = nn.Linear(2 * hidden_size, hidden_size) + else: + self.skip_linear = None + + def forward(self, x, c=None, text_states=None, freq_cis_img=None, skip=None): + # Long Skip Connection + if self.skip_linear is not None: + cat = torch.cat([x, skip], dim=-1) + cat = self.skip_norm(cat) + x = self.skip_linear(cat) + + # Self-Attention + shift_msa = self.default_modulation(c).unsqueeze(dim=1) + attn_inputs = ( + self.norm1(x) + shift_msa, freq_cis_img, + ) + x = x + self.attn1(*attn_inputs)[0] + + # Cross-Attention + cross_inputs = ( + self.norm3(x), text_states, freq_cis_img + ) + x = x + self.attn2(*cross_inputs)[0] + + # FFN Layer + mlp_inputs = self.norm2(x) + x = x + self.mlp(mlp_inputs) + + return x + + +class FinalLayer(nn.Module): + """ + The final layer of HunYuanDiT. + """ + def __init__(self, final_hidden_size, c_emb_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(final_hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(final_hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + FP32_SiLU(), + nn.Linear(c_emb_size, 2 * final_hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class HunYuanDiT(ModelMixin, ConfigMixin): + """ + HunYuanDiT: Diffusion model with a Transformer backbone. 
+ + Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers. + + Parameters + ---------- + args: argparse.Namespace + The arguments parsed by argparse. + input_size: tuple + The size of the input image. + patch_size: int + The size of the patch. + in_channels: int + The number of input channels. + hidden_size: int + The hidden size of the transformer backbone. + depth: int + The number of transformer blocks. + num_heads: int + The number of attention heads. + mlp_ratio: float + The ratio of the hidden size of the MLP in the transformer block. + log_fn: callable + The logging function. + """ + @register_to_config + def __init__( + self, args, + input_size=(32, 32), + patch_size=2, + in_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + log_fn=print, + **kwargs, + ): + super().__init__() + self.args = args + self.log_fn = log_fn + self.depth = depth + self.learn_sigma = args.learn_sigma + self.in_channels = in_channels + self.out_channels = in_channels * 2 if args.learn_sigma else in_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.hidden_size = hidden_size + self.text_states_dim = args.text_states_dim + self.text_states_dim_t5 = args.text_states_dim_t5 + self.text_len = args.text_len + self.text_len_t5 = args.text_len_t5 + self.norm = args.norm + + use_flash_attn = args.infer_mode == 'fa' + if use_flash_attn: + log_fn(f" Enable Flash Attention.") + qk_norm = True # See http://arxiv.org/abs/2302.05442 for details. + + self.mlp_t5 = nn.Sequential( + nn.Linear(self.text_states_dim_t5, self.text_states_dim_t5 * 4, bias=True), + FP32_SiLU(), + nn.Linear(self.text_states_dim_t5 * 4, self.text_states_dim, bias=True), + ) + # learnable replace + self.text_embedding_padding = nn.Parameter( + torch.randn(self.text_len + self.text_len_t5, self.text_states_dim, dtype=torch.float32)) + + # Attention pooling + self.pooler = AttentionPool(self.text_len_t5, self.text_states_dim_t5, num_heads=8, output_dim=1024) + + # Here we use a default learned embedder layer for future extension. + self.style_embedder = nn.Embedding(1, hidden_size) + + # Image size and crop size conditions + self.extra_in_dim = 256 * 6 + hidden_size + + # Text embedding for `add` + self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + self.extra_in_dim += 1024 + self.extra_embedder = nn.Sequential( + nn.Linear(self.extra_in_dim, hidden_size * 4), + FP32_SiLU(), + nn.Linear(hidden_size * 4, hidden_size, bias=True), + ) + + # Image embedding + num_patches = self.x_embedder.num_patches + log_fn(f" Number of tokens: {num_patches}") + + # HUnYuanDiT Blocks + self.blocks = nn.ModuleList([ + HunYuanDiTBlock(hidden_size=hidden_size, + c_emb_size=hidden_size, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + text_states_dim=self.text_states_dim, + use_flash_attn=use_flash_attn, + qk_norm=qk_norm, + norm_type=self.norm, + skip=layer > depth // 2, + ) + for layer in range(depth) + ]) + + self.final_layer = FinalLayer(hidden_size, hidden_size, patch_size, self.out_channels) + self.unpatchify_channels = self.out_channels + + self.initialize_weights() + + def forward(self, + x, + t, + encoder_hidden_states=None, + text_embedding_mask=None, + encoder_hidden_states_t5=None, + text_embedding_mask_t5=None, + image_meta_size=None, + style=None, + cos_cis_img=None, + sin_cis_img=None, + return_dict=True, + ): + """ + Forward pass of the encoder. 
+ + Parameters + ---------- + x: torch.Tensor + (B, D, H, W) + t: torch.Tensor + (B) + encoder_hidden_states: torch.Tensor + CLIP text embedding, (B, L_clip, D) + text_embedding_mask: torch.Tensor + CLIP text embedding mask, (B, L_clip) + encoder_hidden_states_t5: torch.Tensor + T5 text embedding, (B, L_t5, D) + text_embedding_mask_t5: torch.Tensor + T5 text embedding mask, (B, L_t5) + image_meta_size: torch.Tensor + (B, 6) + style: torch.Tensor + (B) + cos_cis_img: torch.Tensor + sin_cis_img: torch.Tensor + return_dict: bool + Whether to return a dictionary. + """ + + text_states = encoder_hidden_states # 2,77,1024 + text_states_t5 = encoder_hidden_states_t5 # 2,256,2048 + text_states_mask = text_embedding_mask.bool() # 2,77 + text_states_t5_mask = text_embedding_mask_t5.bool() # 2,256 + b_t5, l_t5, c_t5 = text_states_t5.shape + text_states_t5 = self.mlp_t5(text_states_t5.view(-1, c_t5)) + text_states = torch.cat([text_states, text_states_t5.view(b_t5, l_t5, -1)], dim=1) # 2,205,1024 + clip_t5_mask = torch.cat([text_states_mask, text_states_t5_mask], dim=-1) + + clip_t5_mask = clip_t5_mask + text_states = torch.where(clip_t5_mask.unsqueeze(2), text_states, self.text_embedding_padding.to(text_states)) + + _, _, oh, ow = x.shape + th, tw = oh // self.patch_size, ow // self.patch_size + + # ========================= Build time and image embedding ========================= + t = self.t_embedder(t) + x = self.x_embedder(x) + + # Get image RoPE embedding according to `reso`lution. + freqs_cis_img = (cos_cis_img, sin_cis_img) + + # ========================= Concatenate all extra vectors ========================= + # Build text tokens with pooling + extra_vec = self.pooler(encoder_hidden_states_t5) + + # Build image meta size tokens + image_meta_size = timestep_embedding(image_meta_size.view(-1), 256) # [B * 6, 256] + if self.args.use_fp16: + image_meta_size = image_meta_size.half() + image_meta_size = image_meta_size.view(-1, 6 * 256) + extra_vec = torch.cat([extra_vec, image_meta_size], dim=1) # [B, D + 6 * 256] + + # Build style tokens + style_embedding = self.style_embedder(style) + extra_vec = torch.cat([extra_vec, style_embedding], dim=1) + + # Concatenate all extra vectors + c = t + self.extra_embedder(extra_vec) # [B, D] + + # ========================= Forward pass through HunYuanDiT blocks ========================= + skips = [] + for layer, block in enumerate(self.blocks): + if layer > self.depth // 2: + skip = skips.pop() + x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D) + else: + x = block(x, c, text_states, freqs_cis_img) # (N, L, D) + + if layer < (self.depth // 2 - 1): + skips.append(x) + + # ========================= Final layer ========================= + x = self.final_layer(x, c) # (N, L, patch_size ** 2 * out_channels) + x = self.unpatchify(x, th, tw) # (N, out_channels, H, W) + + if return_dict: + return {'x': x} + return x + + def calc_rope(self, height, width): + """ + Probably not the best in terms of perf to have this here + """ + th = height // 8 // self.patch_size + tw = width // 8 // self.patch_size + base_size = 512 // 8 // self.patch_size + start, stop = get_fill_resize_and_crop((th, tw), base_size) + sub_args = [start, stop, (th, tw)] + head_size = self.hidden_size // self.num_heads + rope = get_2d_rotary_pos_embed(head_size, *sub_args) + return rope + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not 
None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + nn.init.normal_(self.extra_embedder[0].weight, std=0.02) + nn.init.normal_(self.extra_embedder[2].weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in HunYuanDiT blocks: + for block in self.blocks: + nn.init.constant_(block.default_modulation[-1].weight, 0) + nn.init.constant_(block.default_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x, h, w): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.unpatchify_channels + p = self.x_embedder.patch_size[0] + # h = w = int(x.shape[1] ** 0.5) + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum('nhwpqc->nchpwq', x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs + + +################################################################################# +# HunYuanDiT Configs # +################################################################################# + +HUNYUAN_DIT_CONFIG = { + 'DiT-g/2': {'depth': 40, 'hidden_size': 1408, 'patch_size': 2, 'num_heads': 16, 'mlp_ratio': 4.3637}, + 'DiT-XL/2': {'depth': 28, 'hidden_size': 1152, 'patch_size': 2, 'num_heads': 16}, + 'DiT-L/2': {'depth': 24, 'hidden_size': 1024, 'patch_size': 2, 'num_heads': 16}, + 'DiT-B/2': {'depth': 12, 'hidden_size': 768, 'patch_size': 2, 'num_heads': 12}, +} \ No newline at end of file diff --git a/comfyui-hydit/hydit/modules/norm_layers.py b/comfyui-hydit/hydit/modules/norm_layers.py new file mode 100644 index 0000000..5204ad9 --- /dev/null +++ b/comfyui-hydit/hydit/modules/norm_layers.py @@ -0,0 +1,68 @@ +import torch +import torch.nn as nn + + +class RMSNorm(nn.Module): + def __init__(self, dim: int, elementwise_affine=True, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. 
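The `skip=layer > depth // 2` argument in the block list above, together with the push/pop loop in `forward`, wires the transformer blocks into a U-Net-like layout: the first half of the blocks stack their outputs and the second half consume them as long skip connections. A standalone sketch of that schedule (illustrative only; it just replays the index bookkeeping of the forward pass for the default depth of 28):

```python
# Replay the skip bookkeeping used by HunYuanDiT.forward for depth=28 (DiT-XL/2).
depth = 28
skips = []      # outputs stacked by the first half of the blocks
pairing = {}    # block index in the second half -> block index whose output it consumes

for layer in range(depth):
    if layer > depth // 2:            # second half: pop a stored output as the skip input
        pairing[layer] = skips.pop()
    if layer < (depth // 2 - 1):      # first half: store this block's output
        skips.append(layer)

print(pairing)      # {15: 12, 16: 11, ..., 27: 0}; blocks 13 and 14 form the bottleneck
assert not skips    # every stored output is consumed exactly once
```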
+
+        """
+        output = self._norm(x.float()).type_as(x)
+        if hasattr(self, "weight"):
+            output = output * self.weight
+        return output
+
+
+class GroupNorm32(nn.GroupNorm):
+    def __init__(self, num_groups, num_channels, eps=1e-5, dtype=None):
+        super().__init__(num_groups=num_groups, num_channels=num_channels, eps=eps, dtype=dtype)
+
+    def forward(self, x):
+        y = super().forward(x).to(x.dtype)
+        return y
+
+def normalization(channels, dtype=None):
+    """
+    Make a standard normalization layer.
+    :param channels: number of input channels.
+    :return: an nn.Module for normalization.
+    """
+    return GroupNorm32(num_channels=channels, num_groups=32, dtype=dtype)
diff --git a/comfyui-hydit/hydit/modules/poolers.py b/comfyui-hydit/hydit/modules/poolers.py
new file mode 100644
index 0000000..a4adcac
--- /dev/null
+++ b/comfyui-hydit/hydit/modules/poolers.py
@@ -0,0 +1,39 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class AttentionPool(nn.Module):
+    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+        super().__init__()
+        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim ** 0.5)
+        self.k_proj = nn.Linear(embed_dim, embed_dim)
+        self.q_proj = nn.Linear(embed_dim, embed_dim)
+        self.v_proj = nn.Linear(embed_dim, embed_dim)
+        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+        self.num_heads = num_heads
+
+    def forward(self, x):
+        x = x.permute(1, 0, 2)  # NLC -> LNC
+        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
+        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
+        x, _ = F.multi_head_attention_forward(
+            query=x[:1], key=x, value=x,
+            embed_dim_to_check=x.shape[-1],
+            num_heads=self.num_heads,
+            q_proj_weight=self.q_proj.weight,
+            k_proj_weight=self.k_proj.weight,
+            v_proj_weight=self.v_proj.weight,
+            in_proj_weight=None,
+            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+            bias_k=None,
+            bias_v=None,
+            add_zero_attn=False,
+            dropout_p=0,
+            out_proj_weight=self.c_proj.weight,
+            out_proj_bias=self.c_proj.bias,
+            use_separate_proj_weight=True,
+            training=self.training,
+            need_weights=False
+        )
+        return x.squeeze(0)
diff --git a/comfyui-hydit/hydit/modules/posemb_layers.py b/comfyui-hydit/hydit/modules/posemb_layers.py
new file mode 100644
index 0000000..62c83df
--- /dev/null
+++ b/comfyui-hydit/hydit/modules/posemb_layers.py
@@ -0,0 +1,225 @@
+import torch
+import numpy as np
+from typing import Union
+
+
+def _to_tuple(x):
+    if isinstance(x, int):
+        return x, x
+    else:
+        return x
+
+
+def get_fill_resize_and_crop(src, tgt):  # src: source resolution, tgt: base resolution
+    th, tw = _to_tuple(tgt)
+    h, w = _to_tuple(src)
+
+    tr = th / tw  # aspect ratio of the base resolution
+    r = h / w     # aspect ratio of the target resolution
+
+    # resize
+    if r > tr:
+        resize_height = th
+        resize_width = int(round(th / h * w))
+    else:
+        resize_width = tw
+        resize_height = int(round(tw / w * h))  # resize the target resolution to fit inside the base resolution
+
+    crop_top = int(round((th - resize_height) / 2.0))
+    crop_left = int(round((tw - resize_width) / 2.0))
+
+    return (crop_top, crop_left), (crop_top + resize_height, crop_left + resize_width)
+
+
+def get_meshgrid(start, *args):
+    if len(args) == 0:
+        # start is grid_size
+        num = _to_tuple(start)
+        start = (0, 0)
+        stop = num
+    elif len(args) == 1:
+        # start is start, args[0] is stop, step is 1
+        start = _to_tuple(start)
+        stop = _to_tuple(args[0])
+        num = (stop[0] - start[0], stop[1] - start[1])
+    elif len(args) == 2:
+        # start is start, args[0] is stop, args[1] is num
+        start = _to_tuple(start)   # top-left corner, e.g. (12, 0)
+        stop = _to_tuple(args[0])  # bottom-right corner, e.g. (20, 32)
+        num = _to_tuple(args[1])   # target size, e.g. (32, 124)
+    else:
+        raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
+
+    grid_h = np.linspace(start[0], stop[0], num[0], endpoint=False, dtype=np.float32)  # e.g. 32 samples between 12 and 20, or 124 samples between 0 and 32
+    grid_w = np.linspace(start[1], stop[1], num[1], endpoint=False, dtype=np.float32)
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)  # [2, W, H]
+    return grid
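`get_fill_resize_and_crop` above is what lets the image RoPE generalise across aspect ratios: an arbitrary token grid is scaled to fit inside the square base grid and centred, and the resulting crop corners decide where the rotary frequencies are sampled. A small usage sketch (assuming the `hydit` package from this folder is importable; the numbers are only an example):

```python
from hydit.modules.posemb_layers import get_fill_resize_and_crop

# A 1024x512 (H x W) image becomes a 64x32 token grid after the 8x VAE and 2x patching.
# The default 512-pixel base gives a 512 // 8 // 2 = 32x32 base grid.
start, stop = get_fill_resize_and_crop((64, 32), 32)

print(start, stop)   # (0, 8) (32, 24): rows 0-32 and columns 8-24 of the base grid
# Together with the 64x32 target size these are exactly the sub_args that
# calc_sizes() further down passes on to get_2d_rotary_pos_embed().
```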
+
+#################################################################################
+#                 Sine/Cosine Positional Embedding Functions                    #
+#################################################################################
+# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
+
+def get_2d_sincos_pos_embed(embed_dim, start, *args, cls_token=False, extra_tokens=0):
+    """
+    grid_size: int of the grid height and width
+    return:
+    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+    """
+    grid = get_meshgrid(start, *args)   # [2, H, w]
+    # grid_h = np.arange(grid_size, dtype=np.float32)
+    # grid_w = np.arange(grid_size, dtype=np.float32)
+    # grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    # grid = np.stack(grid, axis=0)  # [2, W, H]
+
+    grid = grid.reshape([2, 1, *grid.shape[1:]])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    if cls_token and extra_tokens > 0:
+        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+    assert embed_dim % 2 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+    """
+    embed_dim: output dimension for each position
+    pos: a list of positions to be encoded: size (W,H)
+    out: (M, D)
+    """
+    assert embed_dim % 2 == 0
+    omega = np.arange(embed_dim // 2, dtype=np.float64)
+    omega /= embed_dim / 2.
+    omega = 1. / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+    emb_sin = np.sin(out)  # (M, D/2)
+    emb_cos = np.cos(out)  # (M, D/2)
+
+    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+    return emb
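For intuition, the sine/cosine construction above boils down to an outer product of positions with log-spaced frequencies, computed separately for the two grid axes and concatenated. A tiny standalone NumPy sketch (re-implemented here for illustration, not imported from the module):

```python
import numpy as np

def sincos_1d(embed_dim, pos):
    # Log-spaced frequencies; half of the channels carry sin, the other half cos.
    omega = 1.0 / 10000 ** (np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2))
    out = np.einsum("m,d->md", pos.reshape(-1).astype(np.float64), omega)   # (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)               # (M, D)

# An 8-dim embedding for a 2x3 token grid: one half of the channels encodes one
# grid axis and the other half the other axis, mirroring the 2d helpers above.
grid_w, grid_h = np.meshgrid(np.arange(3), np.arange(2))   # w goes first, as above
emb = np.concatenate([sincos_1d(4, grid_h), sincos_1d(4, grid_w)], axis=1)
print(emb.shape)   # (6, 8)
```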
+
+
+#################################################################################
+#                   Rotary Positional Embedding Functions                       #
+#################################################################################
+# https://github.com/facebookresearch/llama/blob/main/llama/model.py#L443
+
+def get_2d_rotary_pos_embed(embed_dim, start, *args, use_real=True):
+    """
+    This is a 2d version of precompute_freqs_cis, which is a RoPE for image tokens with 2d structure.
+
+    Parameters
+    ----------
+    embed_dim: int
+        embedding dimension size
+    start: int or tuple of int
+        If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop, step is 1;
+        If len(args) == 2, start is start, args[0] is stop, args[1] is num.
+    use_real: bool
+        If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+
+    Returns
+    -------
+    pos_embed: torch.Tensor
+        [HW, D/2]
+    """
+    grid = get_meshgrid(start, *args)   # [2, H, w]
+    grid = grid.reshape([2, 1, *grid.shape[1:]])   # returns a sampling grid whose resolution matches the target resolution
+    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+    return pos_embed
+
+
+def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
+    assert embed_dim % 4 == 0
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_rotary_pos_embed(embed_dim // 2, grid[0].reshape(-1), use_real=use_real)  # (H*W, D/4)
+    emb_w = get_1d_rotary_pos_embed(embed_dim // 2, grid[1].reshape(-1), use_real=use_real)  # (H*W, D/4)
+
+    if use_real:
+        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D/2)
+        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D/2)
+        return cos, sin
+    else:
+        emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
+        return emb
+
+
+def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False):
+    """
+    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+    and the position indices 'pos'. The 'theta' parameter scales the frequencies.
+    The returned tensor contains complex values in complex64 data type.
+
+    Args:
+        dim (int): Dimension of the frequency tensor.
+        pos (np.ndarray, int): Position indices for the frequency tensor. [S] or scalar
+        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+        use_real (bool, optional): If True, return real part and imaginary part separately.
+                                   Otherwise, return complex numbers.
+
+    Returns:
+        torch.Tensor: Precomputed frequency tensor with complex exponentials. [S, D/2]
+        (a (cos, sin) pair of real tensors is returned instead when use_real=True)
+
+    """
+    if isinstance(pos, int):
+        pos = np.arange(pos)
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))  # [D/2]
+    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
+    freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
+    if use_real:
+        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
+        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+        return freqs_cos, freqs_sin
+    else:
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
+        return freqs_cis
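As a quick illustration of the real-valued tables produced by `get_1d_rotary_pos_embed(..., use_real=True)` above, the sketch below rebuilds the cos/sin tables for a toy case and applies the usual pairwise rotation to a feature vector. The `rotate_half` helper is written here only for the example; the repository applies RoPE inside its attention modules:

```python
import torch

dim, num_pos = 8, 4
freqs = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))   # [D/2]
angles = torch.outer(torch.arange(num_pos).float(), freqs)         # [S, D/2]
cos = angles.cos().repeat_interleave(2, dim=1)                     # [S, D]
sin = angles.sin().repeat_interleave(2, dim=1)                     # [S, D]

def rotate_half(x):
    # Pair up (even, odd) channels, matching the interleaved cos/sin layout above.
    x = x.view(*x.shape[:-1], -1, 2)
    return torch.stack((-x[..., 1], x[..., 0]), dim=-1).flatten(-2)

q = torch.randn(num_pos, dim)
q_rot = q * cos + rotate_half(q) * sin   # query features with rotary positions applied
print(q_rot.shape)                       # torch.Size([4, 8])
```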
+
+
+def calc_sizes(rope_img, patch_size, th, tw):
+    """ Compute the RoPE grid arguments. """
+    if rope_img == 'extend':
+        # extend mode: use the target grid directly
+        sub_args = [(th, tw)]
+    elif rope_img.startswith('base'):
+        # based on a reference grid size; other sizes are obtained by interpolation
+        base_size = int(rope_img[4:]) // 8 // patch_size  # e.g. 512 as the base; other resolutions are interpolated from it
+        start, stop = get_fill_resize_and_crop((th, tw), base_size)  # top-left and bottom-right corners of the crop inside the base grid (e.g. 32x32)
+        sub_args = [start, stop, (th, tw)]
+    else:
+        raise ValueError(f"Unknown rope_img: {rope_img}")
+    return sub_args
+
+
+def init_image_posemb(rope_img,
+                      resolutions,
+                      patch_size,
+                      hidden_size,
+                      num_heads,
+                      log_fn,
+                      rope_real=True,
+                      ):
+    freqs_cis_img = {}
+    for reso in resolutions:
+        th, tw = reso.height // 8 // patch_size, reso.width // 8 // patch_size
+        sub_args = calc_sizes(rope_img, patch_size, th, tw)  # [top-left corner, bottom-right corner, target (h, w)] of the crop inside the base grid
+        freqs_cis_img[str(reso)] = get_2d_rotary_pos_embed(hidden_size // num_heads, *sub_args, use_real=rope_real)
+        log_fn(f"    Using image RoPE ({rope_img}) ({'real' if rope_real else 'complex'}): {sub_args} | ({reso}) "
+               f"{freqs_cis_img[str(reso)][0].shape if rope_real else freqs_cis_img[str(reso)].shape}")
+    return freqs_cis_img
diff --git a/comfyui-hydit/hydit/modules/text_encoder.py b/comfyui-hydit/hydit/modules/text_encoder.py
new file mode 100644
index 0000000..7a16b21
--- /dev/null
+++ b/comfyui-hydit/hydit/modules/text_encoder.py
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration
+
+
+class MT5Embedder(nn.Module):
+    available_models = ["t5-v1_1-xxl"]
+
+    def __init__(
+        self,
+        model_dir="t5-v1_1-xxl",
+        model_kwargs=None,
+        torch_dtype=None,
+        use_tokenizer_only=False,
+        conditional_generation=False,
+        max_length=128,
+    ):
+        super().__init__()
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.torch_dtype = torch_dtype or torch.bfloat16
+        self.max_length = max_length
+        if model_kwargs is None:
+            model_kwargs = {
+                # "low_cpu_mem_usage": True,
+                "torch_dtype": self.torch_dtype,
+            }
+        model_kwargs["device_map"] = {"shared": self.device, "encoder": self.device}
+        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+        if use_tokenizer_only:
+            return
+        if conditional_generation:
+            self.model = None
+            self.generation_model = T5ForConditionalGeneration.from_pretrained(
+                model_dir
+            )
+            return
+        self.model = T5EncoderModel.from_pretrained(model_dir, **model_kwargs).eval().to(self.torch_dtype)
+
+    def get_tokens_and_mask(self, texts):
+        text_tokens_and_mask = self.tokenizer(
+            texts,
+            max_length=self.max_length,
+            padding="max_length",
+            truncation=True,
+            return_attention_mask=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        tokens = text_tokens_and_mask["input_ids"][0]
+        mask = text_tokens_and_mask["attention_mask"][0]
+        # tokens = torch.tensor(tokens).clone().detach()
+        # mask = torch.tensor(mask, dtype=torch.bool).clone().detach()
+        return tokens, mask
+
+    def get_text_embeddings(self, texts, attention_mask=True, layer_index=-1):
+        text_tokens_and_mask = self.tokenizer(
+            texts,
+            max_length=self.max_length,
+            padding="max_length",
+            truncation=True,
+            return_attention_mask=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+
+        with torch.no_grad():
+            outputs = self.model(
+                input_ids=text_tokens_and_mask["input_ids"].to(self.device),
+                attention_mask=text_tokens_and_mask["attention_mask"].to(self.device)
+                if attention_mask
+                else None,
+                output_hidden_states=True,
+            )
+            text_encoder_embs = outputs["hidden_states"][layer_index].detach()
+
+        return text_encoder_embs, text_tokens_and_mask["attention_mask"].to(self.device)
+
+    @torch.no_grad()
+    def __call__(self, tokens, attention_mask, layer_index=-1):
+        with torch.cuda.amp.autocast():
+            
outputs = self.model( + input_ids=tokens, + attention_mask=attention_mask, + output_hidden_states=True, + ) + + z = outputs.hidden_states[layer_index].detach() + return z + + def general(self, text: str): + # input_ids = input_ids = torch.tensor([list(text.encode("utf-8"))]) + num_special_tokens + input_ids = self.tokenizer(text, max_length=128).input_ids + print(input_ids) + outputs = self.generation_model(input_ids) + return outputs \ No newline at end of file diff --git a/comfyui-hydit/hydit/utils/tools.py b/comfyui-hydit/hydit/utils/tools.py new file mode 100644 index 0000000..66c0b03 --- /dev/null +++ b/comfyui-hydit/hydit/utils/tools.py @@ -0,0 +1,17 @@ +import random + +import numpy as np +import torch + + +def set_seeds(seed_list, device=None): + if isinstance(seed_list, (tuple, list)): + seed = sum(seed_list) + else: + seed = seed_list + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + return torch.Generator(device).manual_seed(seed) diff --git a/comfyui-hydit/img/txt2img_v2.png b/comfyui-hydit/img/txt2img_v2.png new file mode 100644 index 0000000..6ab0105 Binary files /dev/null and b/comfyui-hydit/img/txt2img_v2.png differ diff --git a/comfyui-hydit/nodes.py b/comfyui-hydit/nodes.py new file mode 100644 index 0000000..3f5a9ce --- /dev/null +++ b/comfyui-hydit/nodes.py @@ -0,0 +1,208 @@ +import copy +import os +import torch +from .utils import convert_images_to_tensors +from comfy.model_management import get_torch_device +import folder_paths +from .hydit.diffusion.pipeline import StableDiffusionPipeline +from .hydit.config import get_args +from .hydit.inference import End2End +from pathlib import Path +from .hydit.constants import SAMPLER_FACTORY +from diffusers import schedulers +from .constant import HUNYUAN_PATH, SCHEDULERS_hunyuan +from .dit import load_dit + +class DiffusersPipelineLoader: + def __init__(self): + self.tmp_dir = folder_paths.get_temp_directory() + self.dtype = torch.float32 + + + @classmethod + def INPUT_TYPES(s): + return {"required": { "pipeline_folder_name": (os.listdir(HUNYUAN_PATH), ), + "model_name": (["disable"] + folder_paths.get_filename_list("checkpoints"), ), + "vae_name": (["disable"] + folder_paths.get_filename_list("vae"), ), + "backend": (["ksampler", "diffusers"], ), }} + + RETURN_TYPES = ("PIPELINE", "MODEL", "CLIP", "VAE") + + FUNCTION = "create_pipeline" + + CATEGORY = "Diffusers" + + def create_pipeline(self, pipeline_folder_name, model_name, vae_name, backend): + if model_name != "disable": + MODEL_PATH = folder_paths.get_full_path("checkpoints", model_name) + else: + MODEL_PATH = None + if vae_name != "disable": + VAE_PATH = folder_paths.get_full_path("vae", vae_name) + else: + VAE_PATH = None + + if backend == "diffusers": + args_hunyuan = get_args() + gen = End2End(args_hunyuan[0], Path(os.path.join(HUNYUAN_PATH, pipeline_folder_name)), MODEL_PATH, VAE_PATH) + return (gen, None, None, None) + elif backend == "ksampler": + out = load_dit(model_path = os.path.join(HUNYUAN_PATH, pipeline_folder_name), MODEL_PATH = MODEL_PATH, VAE_PATH = VAE_PATH) + return (None,) + out[:3] + +class DiffusersSchedulerLoader: + def __init__(self): + self.tmp_dir = folder_paths.get_temp_directory() + self.dtype = torch.float32 + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "scheduler_name": (list(SCHEDULERS_hunyuan), ), + } + } + + RETURN_TYPES = ("SCHEDULER",) + + FUNCTION = "load_scheduler" + + CATEGORY = "Diffusers" + + def load_scheduler(self, scheduler_name): + # Load sampler 
from factory + kwargs = SAMPLER_FACTORY[scheduler_name]['kwargs'] + scheduler = SAMPLER_FACTORY[scheduler_name]['scheduler'] + args_hunyuan = get_args() + args_hunyuan = args_hunyuan[0] + + # Update sampler according to the arguments + kwargs['beta_schedule'] = args_hunyuan.noise_schedule + kwargs['beta_start'] = args_hunyuan.beta_start + kwargs['beta_end'] = args_hunyuan.beta_end + kwargs['prediction_type'] = args_hunyuan.predict_type + + # Build scheduler according to the sampler. + scheduler_class = getattr(schedulers, scheduler) + scheduler = scheduler_class(**kwargs) + + + return (scheduler,) + +class DiffusersModelMakeup: + def __init__(self): + self.torch_device = get_torch_device() + + @classmethod + def INPUT_TYPES(s): + return { + "required": { + "pipeline": ("PIPELINE", ), + "scheduler": ("SCHEDULER", ), + }, + } + + RETURN_TYPES = ("MAKED_PIPELINE",) + + FUNCTION = "makeup_pipeline" + + CATEGORY = "Diffusers" + + def makeup_pipeline(self, pipeline, scheduler): + progress_bar_config = {} + + pipe = StableDiffusionPipeline(vae=pipeline.vae, + text_encoder=pipeline.clip_text_encoder, + tokenizer=pipeline.tokenizer, + unet=pipeline.model, + scheduler=scheduler, + feature_extractor=None, + safety_checker=None, + requires_safety_checker=False, + progress_bar_config=progress_bar_config, + embedder_t5=pipeline.embedder_t5, + infer_mode=pipeline.infer_mode, + ) + + pipe = pipe.to(pipeline.device) + pipeline.pipeline = pipe + return (pipeline,) + +class DiffusersClipTextEncode: + @classmethod + def INPUT_TYPES(s): + return {"required": { + "positive": ("STRING", {"multiline": True}), + "negative": ("STRING", {"multiline": True}), + }} + + RETURN_TYPES = ("STRINGC", "STRINGC", ) + RETURN_NAMES = ("positive", "negative", ) + + FUNCTION = "concat_embeds" + + CATEGORY = "Diffusers" + + def concat_embeds(self, positive, negative): + + return (positive, negative, ) + + +class DiffusersSampler: + def __init__(self): + self.torch_device = get_torch_device() + + @classmethod + def INPUT_TYPES(s): + return {"required": { + "maked_pipeline": ("MAKED_PIPELINE", ), + "positive": ("STRINGC",), + "negative": ("STRINGC",), + "batch_size": ("INT", {"default": 1, "min": 1, "max": 128, "step": 1}), + "width": ("INT", {"default": 1024, "min": 1, "max": 8192, "step": 1}), + "height": ("INT", {"default": 1024, "min": 1, "max": 8192, "step": 1}), + "steps": ("INT", {"default": 20, "min": 1, "max": 10000}), + "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0, "step":0.1, "round": 0.01}), + "seed": ("INT", {"default": 0, "min": 0, "max": 2**32-1}), + }} + + RETURN_TYPES = ("IMAGE",) + + FUNCTION = "sample" + + CATEGORY = "Diffusers" + + def sample(self, maked_pipeline, positive, negative, batch_size, width, height, steps, cfg, seed): + + results = maked_pipeline.predict(positive, + height=height, + width=width, + seed=int(seed), + enhanced_prompt=None, + negative_prompt=negative, + infer_steps=steps, + guidance_scale=cfg, + batch_size=batch_size, + src_size_cond=[height, width], + ) + images = results['images'] + return (convert_images_to_tensors(images),) + + + + +NODE_CLASS_MAPPINGS = { + "DiffusersPipelineLoader": DiffusersPipelineLoader, + "DiffusersSchedulerLoader": DiffusersSchedulerLoader, + "DiffusersModelMakeup": DiffusersModelMakeup, + "DiffusersClipTextEncode": DiffusersClipTextEncode, + "DiffusersSampler": DiffusersSampler, +} + +NODE_DISPLAY_NAME_MAPPINGS = { + "DiffusersPipelineLoader": "HunYuan Pipeline Loader", + "DiffusersSchedulerLoader": "HunYuan Scheduler Loader", + 
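All of the classes registered in these mappings follow the standard ComfyUI custom-node convention: `INPUT_TYPES` declares the input sockets and widgets, `RETURN_TYPES`/`FUNCTION`/`CATEGORY` describe the node, and exporting `NODE_CLASS_MAPPINGS` from the package is what makes the nodes appear once the folder sits under `custom_nodes`. A stripped-down sketch of that pattern (a hypothetical node, for illustration only):

```python
class HunyuanExampleNode:
    """Minimal shape of a ComfyUI node, mirroring the classes defined above."""

    @classmethod
    def INPUT_TYPES(cls):
        # Socket name -> (type, options); "STRING"/"INT"/"FLOAT" are built-in widget types.
        return {"required": {"text": ("STRING", {"multiline": True}),
                             "steps": ("INT", {"default": 20, "min": 1, "max": 100})}}

    RETURN_TYPES = ("STRING",)
    FUNCTION = "run"          # method ComfyUI calls when the node executes
    CATEGORY = "Diffusers"    # menu section, matching the nodes above

    def run(self, text, steps):
        return (f"{text} ({steps} steps)",)

# Exporting these mappings (e.g. from __init__.py) is what registers the node.
NODE_CLASS_MAPPINGS = {"HunyuanExampleNode": HunyuanExampleNode}
NODE_DISPLAY_NAME_MAPPINGS = {"HunyuanExampleNode": "HunYuan Example Node"}
```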
"DiffusersModelMakeup": "HunYuan Model Makeup", + "DiffusersClipTextEncode": "HunYuan Clip Text Encode", + "DiffusersSampler": "HunYuan Sampler", +} diff --git a/comfyui-hydit/requirements.txt b/comfyui-hydit/requirements.txt new file mode 100644 index 0000000..0834b24 --- /dev/null +++ b/comfyui-hydit/requirements.txt @@ -0,0 +1,18 @@ +--extra-index-url https://pypi.ngc.nvidia.com +--extra-index-url https://download.pytorch.org/whl/cu117 +timm +diffusers +peft +protobuf +accelerate +loguru +sentencepiece +cuda-python +polygraphy +pandas +omegaconf +torch==2.0.1 +torchvision==0.15.2 +torchaudio==2.0.2 +xformers==0.0.20 +pytorch_lightning \ No newline at end of file diff --git a/comfyui-hydit/supported_dit_models.py b/comfyui-hydit/supported_dit_models.py new file mode 100644 index 0000000..8d26855 --- /dev/null +++ b/comfyui-hydit/supported_dit_models.py @@ -0,0 +1,93 @@ +import comfy.supported_models_base +import comfy.latent_formats +import comfy.model_patcher +import comfy.model_base +import comfy.utils +import torch +from collections import namedtuple +from .hydit.modules.models_comfyui import HunYuanDiT as HYDiT + +def batch_embeddings(embeds, batch_size): + bs_embed, seq_len, _ = embeds.shape + embeds = embeds.repeat(1, batch_size, 1) + embeds = embeds.view(bs_embed * batch_size, seq_len, -1) + return embeds + +class HunYuan_DiT(comfy.supported_models_base.BASE): + Conf = namedtuple('DiT', ['learn_sigma', 'text_states_dim', 'text_states_dim_t5', 'text_len', 'text_len_t5', 'norm', 'infer_mode', 'use_fp16']) + conf = { + 'learn_sigma': True, + 'text_states_dim': 1024, + 'text_states_dim_t5': 2048, + 'text_len': 77, + 'text_len_t5': 256, + 'norm': 'layer', + 'infer_mode': 'torch', + 'use_fp16': True + } + + unet_config = {} + unet_extra_config = { + "num_heads": 16 + } + latent_format = comfy.latent_formats.SDXL + + dit_conf = Conf(**conf) + + def __init__(self, model_conf): + self.unet_config = model_conf.get("unet_config", {}) + #print(model_conf) + print(self.unet_config) + self.sampling_settings = model_conf.get("sampling_settings", {}) + self.latent_format = self.latent_format() + self.unet_config["disable_unet_model_creation"] = True + #self.unet_config["disable_unet_model_creation"] = self.unet_config.get("disable_unet_model_creation", True) + + def model_type(self, state_dict, prefix=""): + return comfy.model_base.ModelType.V_PREDICTION + +class HYDiT_Model(comfy.model_base.BaseModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def extra_conds(self, **kwargs): + out = super().extra_conds(**kwargs) + addit_embeds = kwargs['cross_attn'].addit_embeds + for name in addit_embeds: + out[name] = comfy.conds.CONDRegular(addit_embeds[name]) + + return out + +class ModifiedHunYuanDiT(HYDiT): + def forward_core(self, *args, **kwargs): + return super().forward(*args, **kwargs) + + def forward(self, x, timesteps, context, t5_embeds=None, attention_mask=None, t5_attention_mask=None, image_meta_size=None, **kwargs): + batch_size, _, height, width = x.shape + #assert(0) + + + style = torch.as_tensor([0, 0] * (batch_size//2), device=x.device) + src_size_cond = (width//2*16, height//2*16) + size_cond = list(src_size_cond) + [width*8, height*8, 0, 0] + image_meta_size = torch.as_tensor([size_cond] * batch_size, device=x.device) + rope = self.calc_rope(*src_size_cond) + + + noise_pred = self.forward_core( + x = x.to(self.dtype), + t = timesteps.to(self.dtype), + encoder_hidden_states = context.to(self.dtype), + text_embedding_mask = attention_mask.to(self.dtype), 
+ encoder_hidden_states_t5 = t5_embeds.to(self.dtype), + text_embedding_mask_t5 = t5_attention_mask.to(self.dtype), + image_meta_size = image_meta_size.to(self.dtype), + style = style, + cos_cis_img = rope[0], + sin_cis_img = rope[1], + return_dict=False + ) + noise_pred = noise_pred.to(torch.float) + eps, _ = noise_pred[:, :self.in_channels], noise_pred[:, self.in_channels:] + return eps + diff --git a/comfyui-hydit/utils.py b/comfyui-hydit/utils.py new file mode 100644 index 0000000..b41cb35 --- /dev/null +++ b/comfyui-hydit/utils.py @@ -0,0 +1,215 @@ +import io +import torch +import requests +import numpy as np +from PIL import Image +from omegaconf import OmegaConf +from torchvision.transforms import ToTensor +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( + assign_to_checkpoint, + conv_attn_to_linear, + create_vae_diffusers_config, + renew_vae_attention_paths, + renew_vae_resnet_paths, +) +from diffusers import ( + AutoencoderKL, + DDIMScheduler, + DDPMScheduler, + DEISMultistepScheduler, + DPMSolverMultistepScheduler, + DPMSolverSinglestepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + HeunDiscreteScheduler, + KDPM2AncestralDiscreteScheduler, + KDPM2DiscreteScheduler, + UniPCMultistepScheduler, +) + +SCHEDULERS = { + 'DDIM' : DDIMScheduler, + 'DDPM' : DDPMScheduler, + 'DEISMultistep' : DEISMultistepScheduler, + 'DPMSolverMultistep' : DPMSolverMultistepScheduler, + 'DPMSolverSinglestep' : DPMSolverSinglestepScheduler, + 'EulerAncestralDiscrete' : EulerAncestralDiscreteScheduler, + 'EulerDiscrete' : EulerDiscreteScheduler, + 'HeunDiscrete' : HeunDiscreteScheduler, + 'KDPM2AncestralDiscrete' : KDPM2AncestralDiscreteScheduler, + 'KDPM2Discrete' : KDPM2DiscreteScheduler, + 'UniPCMultistep' : UniPCMultistepScheduler +} + +SCHEDULERS_hunyuan = ["ddpm", "ddim", "dpmms"] + +def token_auto_concat_embeds(pipe, positive, negative): + max_length = pipe.tokenizer.model_max_length + positive_length = pipe.tokenizer(positive, return_tensors="pt").input_ids.shape[-1] + negative_length = pipe.tokenizer(negative, return_tensors="pt").input_ids.shape[-1] + + print(f'Token length is model maximum: {max_length}, positive length: {positive_length}, negative length: {negative_length}.') + if max_length < positive_length or max_length < negative_length: + print('Concatenated embedding.') + if positive_length > negative_length: + positive_ids = pipe.tokenizer(positive, return_tensors="pt").input_ids.to("cuda") + negative_ids = pipe.tokenizer(negative, truncation=False, padding="max_length", max_length=positive_ids.shape[-1], return_tensors="pt").input_ids.to("cuda") + else: + negative_ids = pipe.tokenizer(negative, return_tensors="pt").input_ids.to("cuda") + positive_ids = pipe.tokenizer(positive, truncation=False, padding="max_length", max_length=negative_ids.shape[-1], return_tensors="pt").input_ids.to("cuda") + else: + positive_ids = pipe.tokenizer(positive, truncation=False, padding="max_length", max_length=max_length, return_tensors="pt").input_ids.to("cuda") + negative_ids = pipe.tokenizer(negative, truncation=False, padding="max_length", max_length=max_length, return_tensors="pt").input_ids.to("cuda") + + positive_concat_embeds = [] + negative_concat_embeds = [] + for i in range(0, positive_ids.shape[-1], max_length): + positive_concat_embeds.append(pipe.text_encoder(positive_ids[:, i: i + max_length])[0]) + negative_concat_embeds.append(pipe.text_encoder(negative_ids[:, i: i + max_length])[0]) + + positive_prompt_embeds = 
torch.cat(positive_concat_embeds, dim=1) + negative_prompt_embeds = torch.cat(negative_concat_embeds, dim=1) + return positive_prompt_embeds, negative_prompt_embeds + +# Reference from : https://github.com/huggingface/diffusers/blob/main/scripts/convert_vae_pt_to_diffusers.py +def custom_convert_ldm_vae_checkpoint(checkpoint, config): + vae_state_dict = checkpoint + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + 
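`custom_convert_ldm_vae_checkpoint` is, at its core, a systematic key rename from the original LDM VAE layout to the diffusers `AutoencoderKL` layout (plus reshaping the attention projections via `conv_attn_to_linear`). A toy sketch of the idea using a couple of the direct renames visible above (illustrative only; the full mapping comes from the helpers imported from `diffusers.pipelines.stable_diffusion.convert_from_ckpt`):

```python
# Illustrative only: a few of the direct renames performed by the function above.
RENAMES = {
    "encoder.norm_out.weight": "encoder.conv_norm_out.weight",
    "encoder.norm_out.bias": "encoder.conv_norm_out.bias",
    "encoder.down.0.downsample.conv.weight": "encoder.down_blocks.0.downsamplers.0.conv.weight",
}

def rename(state_dict):
    # Keys without a special rule keep their original name.
    return {RENAMES.get(k, k): v for k, v in state_dict.items()}

old = {"encoder.norm_out.weight": 0.0, "encoder.conv_in.weight": 1.0}
print(rename(old))
# {'encoder.conv_norm_out.weight': 0.0, 'encoder.conv_in.weight': 1.0}
```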
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + +# Reference from : https://github.com/huggingface/diffusers/blob/main/scripts/convert_vae_pt_to_diffusers.py +def vae_pt_to_vae_diffuser( + checkpoint_path: str, + output_path: str, +): + # Only support V1 + r = requests.get( + " https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + ) + io_obj = io.BytesIO(r.content) + + original_config = OmegaConf.load(io_obj) + image_size = 512 + device = "cuda" if torch.cuda.is_available() else "cpu" + if checkpoint_path.endswith("safetensors"): + from safetensors import safe_open + + checkpoint = {} + with safe_open(checkpoint_path, framework="pt", device="cpu") as f: + for key in f.keys(): + checkpoint[key] = f.get_tensor(key) + else: + checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"] + + # Convert the VAE model. + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + vae.save_pretrained(output_path) + + +def convert_images_to_tensors(images: list[Image.Image]): + return torch.stack([np.transpose(ToTensor()(image), (1, 2, 0)) for image in images]) + +def convert_tensors_to_images(images: torch.tensor): + return [Image.fromarray(np.clip(255. 
* image.cpu().numpy(), 0, 255).astype(np.uint8)) for image in images] + +def resize_images(images: list[Image.Image], size: tuple[int, int]): + return [image.resize(size) for image in images] \ No newline at end of file diff --git a/comfyui-hydit/workflow/hunyuan_diffusers_api.json b/comfyui-hydit/workflow/hunyuan_diffusers_api.json new file mode 100644 index 0000000..bc610c2 --- /dev/null +++ b/comfyui-hydit/workflow/hunyuan_diffusers_api.json @@ -0,0 +1,87 @@ +{ + "6": { + "inputs": { + "images": [ + "18", + 0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image" + } + }, + "15": { + "inputs": { + "scheduler_name": "ddim" + }, + "class_type": "DiffusersSchedulerLoader", + "_meta": { + "title": "HunYuan Scheduler Loader" + } + }, + "16": { + "inputs": { + "pipeline": [ + "21", + 0 + ], + "scheduler": [ + "15", + 0 + ] + }, + "class_type": "DiffusersModelMakeup", + "_meta": { + "title": "HunYuan Model Makeup" + } + }, + "18": { + "inputs": { + "batch_size": 1, + "width": 1024, + "height": 1024, + "steps": 30, + "cfg": 6, + "seed": 8806508, + "maked_pipeline": [ + "16", + 0 + ], + "positive": [ + "19", + 0 + ], + "negative": [ + "19", + 1 + ] + }, + "class_type": "DiffusersSampler", + "_meta": { + "title": "HunYuan Sampler" + } + }, + "19": { + "inputs": { + "positive": "描绘的风格是写实,画面主要描述一双泥泞的靴子在雨天里,靴子颜色是棕色,泥沙溅在 Boot 的表面,背景是湿漉漉的地板,泥泞的环境,天气是阴沉的雨天,镜头是近景", + "negative": "错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺," + }, + "class_type": "DiffusersClipTextEncode", + "_meta": { + "title": "HunYuan Clip Text Encode" + } + }, + "21": { + "inputs": { + "pipeline_folder_name": "ckpts", + "model_name": "pytorch_model_ema.pt", + "vae_name": "disable", + "backend": "diffusers" + }, + "class_type": "DiffusersPipelineLoader", + "_meta": { + "title": "HunYuan Pipeline Loader" + } + } +} \ No newline at end of file diff --git a/comfyui-hydit/workflow/hunyuan_ksampler_api.json b/comfyui-hydit/workflow/hunyuan_ksampler_api.json new file mode 100644 index 0000000..116e3c3 --- /dev/null +++ b/comfyui-hydit/workflow/hunyuan_ksampler_api.json @@ -0,0 +1,109 @@ +{ + "21": { + "inputs": { + "pipeline_folder_name": "ckpts", + "model_name": "pytorch_model_ema.pt", + "vae_name": "disable", + "backend": "ksampler" + }, + "class_type": "DiffusersPipelineLoader", + "_meta": { + "title": "HunYuan Pipeline Loader" + } + }, + "24": { + "inputs": { + "seed": 8806508, + "steps": 30, + "cfg": 6, + "sampler_name": "ddim", + "scheduler": "ddim_uniform", + "denoise": 1, + "model": [ + "21", + 1 + ], + "positive": [ + "25", + 0 + ], + "negative": [ + "26", + 0 + ], + "latent_image": [ + "27", + 0 + ] + }, + "class_type": "KSampler", + "_meta": { + "title": "KSampler" + } + }, + "25": { + "inputs": { + "text": "描绘的风格是写实,画面主要描述一双泥泞的靴子在雨天里,靴子颜色是棕色,泥沙溅在 Boot 的表面,背景是湿漉漉的地板,泥泞的环境,天气是阴沉的雨天,镜头是近景", + "clip": [ + "21", + 2 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "26": { + "inputs": { + "text": "错误的眼睛,糟糕的人脸,毁容,糟糕的艺术,变形,多余的肢体,模糊的颜色,模糊,重复,病态,残缺,", + "clip": [ + "21", + 2 + ] + }, + "class_type": "CLIPTextEncode", + "_meta": { + "title": "CLIP Text Encode (Prompt)" + } + }, + "27": { + "inputs": { + "width": 1024, + "height": 1024, + "batch_size": 1 + }, + "class_type": "EmptyLatentImage", + "_meta": { + "title": "Empty Latent Image" + } + }, + "28": { + "inputs": { + "samples": [ + "24", + 0 + ], + "vae": [ + "21", + 3 + ] + }, + "class_type": "VAEDecode", + "_meta": { + "title": "VAE Decode" + } + }, + "29": { + "inputs": { + "images": [ + "28", + 
0 + ] + }, + "class_type": "PreviewImage", + "_meta": { + "title": "Preview Image" + } + } +} \ No newline at end of file