Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized the training startup process(【待解决问题5】优化混元DiT模型启动配置流程) #156

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 35 additions & 1 deletion hydit/config.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import os
import argparse
import importlib.util

from .constants import *
from .modules.models import HUNYUAN_DIT_CONFIG, HUNYUAN_DIT_MODELS
Expand All @@ -12,6 +14,33 @@ def model_var_type(value):
except KeyError:
raise ValueError(f"Invalid choice '{value}', valid choices are {[v.name for v in ModelVarType]}")

def load_config(training_type):
config_path = "./hydit/config/train_config.py" if training_type == 'full' \
else f"./hydit/config/train_{training_type}_config.py"
config_path = os.path.normpath(config_path)
spec = importlib.util.spec_from_file_location("train_config", config_path)
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)
return config.train_config


def merge_args_with_config(args, config, parser):
def is_default_or_none_value(parser, key, value):
for action in parser._actions:
if action.dest == key:
return value is None or action.default == value
return False

for key, value in config.items():
if isinstance(value, dict):
for sub_key, sub_value in value.items():
if is_default_or_none_value(parser, sub_key, getattr(args, key, None)):
setattr(args, sub_key, sub_value)
else:
if is_default_or_none_value(parser, key, getattr(args, key, None)):
setattr(args, key, value)
return args


def get_args(default_args=None):
parser = argparse.ArgumentParser()
Expand All @@ -32,7 +61,7 @@ def get_args(default_args=None):
'(value, value).')
parser.add_argument("--qk-norm", action="store_true", help="Query Key normalization. See http://arxiv.org/abs/2302.05442 for details.")
parser.set_defaults(qk_norm=True)
parser.add_argument("--norm", type=str, choices=["rms", "laryer"], default="layer", help="Normalization layer type")
parser.add_argument("--norm", type=str, choices=["rms", "layer"], default="layer", help="Normalization layer type")
parser.add_argument("--text-states-dim", type=int, default=1024, help="Hidden size of CLIP text encoder.")
parser.add_argument("--text-len", type=int, default=77, help="Token length of CLIP text encoder output.")
parser.add_argument("--text-states-dim-t5", type=int, default=2048, help="Hidden size of CLIP text encoder.")
Expand Down Expand Up @@ -194,6 +223,11 @@ def get_args(default_args=None):
parser.add_argument('--zero-stage', type=int, default=1)
parser.add_argument("--async-ema", action="store_true", help="Whether to use multi stream to excut EMA.")

parser.add_argument("--training-type", type=str, required=True, default='full', choices=['full', 'controlnet', 'lora'],
help="Specify the type of training")
args = parser.parse_args(default_args)

config = load_config(args.training_type)
args = merge_args_with_config(args, config, parser)

return args
77 changes: 77 additions & 0 deletions hydit/config/train_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
train_config = {
'task_flag': 'dit_g2_full_1024p', # the task flag is used to identify folders
'model_config': {
'model': 'DiT-g/2', # choices = ["DiT-g/2", "DiT-XL/2"]
'image_size': [1024, 1024], # training image resolution
'qk_norm': True, # Query Key normalization
'norm': 'layer', # normalization layer type, choices=["rms", "layer"]
'text_states_dim': 1024, # hidden size of CLIP text encoder
'text_len': 77, # token length of CLIP text encoder output
'text_states_dim_t5': 2048, # hidden size of CLIP text encoder
'text_len_t5': 256, # token length of T5 text encoder output
'learn_sigma': False, # learn extra channels for sigma
'predict_type': 'v_prediction', # choices = ["epsilon", "sample", "v_prediction"]
'noise_schedule': 'scaled_linear', # choices = ["linear", "scaled_linear", "squaredcos_cap_v2"]
'beta_start': 0.00085, # beta start value
'beta_end': 0.03, # beta end value
'sigma_small': False, # if True, use a smaller fixed sigma otherwise a larger one
'mse_loss_weight_type': 'constant', # Min-SNR-gamma, choices = ['constant', 'min_snr_<gamma>'(gamma is a integer)]
'model_var_type': None, # specify the model variable type
'noise_offset': 0.0 # add extra noise to the input image
},
'dataset_config': {
'batch_size': 1, # per-GPU batch size
'seed': 42, # a seed for all the prompts
'index_file':
'dataset/porcelain/jsons/porcelain.json', # index file for dataloader
'random_flip': True, # random flip image
'reset_loader': False, # reset the data loader
'multireso': False, # use multi-resolution training
'reso_step': None, # step size for multi-resolution training
'random_shrink_size_cond': False, # randomly shrink the original size condition
'merge_src_cond': False # merge the source condition into a single value
},
'training_config': {
'lr': 0.0001, # learning rate
'epochs': 1400, # training epochs
'max_training_steps': 10000000, # max training steps
'gc_interval': 40, # frequency (in steps) to invoke gc.collect()
'log_every': 10, # frequency (in steps) to log training progress
'ckpt_every': 10000, # frequency (in steps) to create a ckpt
'ckpt_latest_every': 5000, # frequency (in steps) to create a ckpt named `latest.pt`
'num_workers': 4, # number of workers for data loading
'global_seed': 999, # global random seed
'warmup_min_lr': 0.000001, # minimum learning rate during warmup
'warmup_num_steps': 0, # number of steps to warm up the learning rate
'weight_decay': 0, # weight-decay in optimizer
'rope_img': 'base512', # extend or interpolate the positional embedding of the image, choices = ['extend', 'base512', 'base1024']
'rope_real': True, # use real part and imaginary part separately for RoPE
'uncond_p': 0.44, # the probability of dropping training text used for CLIP feature extraction
'uncond_p_t5': 0.44, # the probability of dropping training text used for mT5 feature extraction
'results_dir': './log_EXP', # save root for results
'resume': './ckpts/t2i/model/', # resume experiment from a checkpoint
'strict': True, # strict loading of checkpoint
'resume_deepspeed': False, # resume model and ema states from a checkpoint saved by Deepspeed version of DIT
'resume_split': False, # resume model and ema states from two checkpoint separated from DeepSpeed ckpt
'ema_to_module': False, # if true, initialize the module with EMA weights
'module_to_ema': False, # if true, initialize the ema with Module weights
'use_ema': True, # use EMA model
'ema_dtype': 'fp32', # choices = ['fp16', 'fp32', 'none']. if none, use the same data type as the model
'ema_decay': None, # EMA decay rate. If None, use the default value of the model
'ema_warmup': False, # EMA warmup. If True, perform ema_decay warmup from 0 to ema_decay
'ema_warmup_power': None, # EMA power. If None, use the default value of the model
'ema_reset_decay': False, # reset EMA decay to 0 and restart increasing the EMA decay
'use_flash_attn': True, # use flash attention to accelerate training
'use_zero_stage': 2, # use AngelPTM zero stage. choices = [1, 2, 3]
'grad_accu_steps': 1, # gradient accumulation steps
'use_fp16': True, # use FP16 precision
'extra_fp16': False # use extra fp16 for vae and text_encoder
},
'deepspeed_config': {
'local_rank': None, # local rank passed from distributed launcher.
'deepspeed_optimizer': True, # switching to the optimizers in DeepSpeed.
'remote_device': 'none', # remote device for ZeRO-3 initialized parameters. choices = ['none', 'cpu', 'nvme'].
'zero_stage': 1, # ZeRO optimization stage.
'async_ema': False # whether to use multi-stream to execute EMA.
}
}
82 changes: 82 additions & 0 deletions hydit/config/train_controlnet_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
train_config = {
'task_flag': 'canny_controlnet', # the task flag is used to identify folders
'model_config': {
'model': 'DiT-g/2', # choices = ["DiT-g/2", "DiT-XL/2"]
'image_size': [1024, 1024], # training image resolution
'qk_norm': True, # Query Key normalization
'norm': 'layer', # normalization layer type, choices=["rms", "layer"]
'text_states_dim': 1024, # hidden size of CLIP text encoder
'text_len': 77, # token length of CLIP text encoder output
'text_states_dim_t5': 2048, # hidden size of CLIP text encoder
'text_len_t5': 256, # token length of T5 text encoder output
'learn_sigma': False, # learn extra channels for sigma
'predict_type': 'v_prediction', # choices = ["epsilon", "sample", "v_prediction"]
'noise_schedule': 'scaled_linear', # choices = ["linear", "scaled_linear", "squaredcos_cap_v2"]
'beta_start': 0.00085, # beta start value
'beta_end': 0.03, # beta end value
'sigma_small': False, # if True, use a smaller fixed sigma otherwise a larger one
'mse_loss_weight_type': 'constant', # Min-SNR-gamma, choices = ['constant', 'min_snr_<gamma>'(gamma is a integer)]
'model_var_type': None, # specify the model variable type
'noise_offset': 0.0 # add extra noise to the input image
},
'dataset_config': {
'batch_size': 1, # per-GPU batch size
'seed': 42, # a seed for all the prompts
'index_file': '/path/to/your/indexfile', # index file for dataloader
'random_flip': True, # random flip image
'reset_loader': False, # reset the data loader
'multireso': True, # use multi-resolution training
'reso_step': 64, # step size for multi-resolution training
'random_shrink_size_cond': False, # randomly shrink the original size condition
'merge_src_cond': False # merge the source condition into a single value
},
'training_config': {
'lr': 0.0001, # learning rate
'epochs': 1400, # training epochs
'max_training_steps': 10000000, # max training steps
'gc_interval': 40, # frequency (in steps) to invoke gc.collect()
'log_every': 10, # frequency (in steps) to log training progress
'ckpt_every': 10000, # frequency (in steps) to create a ckpt
'ckpt_latest_every': 5000, # frequency (in steps) to create a ckpt named `latest.pt`
'num_workers': 4, # number of workers for data loading
'global_seed': 999, # global random seed
'warmup_min_lr': 0.000001, # minimum learning rate during warmup
'warmup_num_steps': 0, # number of steps to warm up the learning rate
'weight_decay': 0, # weight-decay in optimizer
'rope_img': 'base512', # extend or interpolate the positional embedding of the image, choices = ['extend', 'base512', 'base1024']
'rope_real': True, # use real part and imaginary part separately for RoPE
'uncond_p': 0.44, # the probability of dropping training text used for CLIP feature extraction
'uncond_p_t5': 0.44, # the probability of dropping training text used for mT5 feature extraction
'results_dir': './log_EXP', # save root for results
'resume': './ckpts/t2i/model/', # resume experiment from a checkpoint
'strict': True, # strict loading of checkpoint
'resume_deepspeed': False, # resume model and ema states from a checkpoint saved by Deepspeed version of DIT
'resume_split': True, # resume model and ema states from two checkpoint separated from DeepSpeed ckpt
'ema_to_module': True, # if true, initialize the module with EMA weights
'module_to_ema': False, # if true, initialize the ema with Module weights
'use_ema': True, # use EMA model
'ema_dtype': 'fp32', # choices = ['fp16', 'fp32', 'none']. if none, use the same data type as the model
'ema_decay': None, # EMA decay rate. If None, use the default value of the model
'ema_warmup': False, # EMA warmup. If True, perform ema_decay warmup from 0 to ema_decay
'ema_warmup_power': None, # EMA power. If None, use the default value of the model
'ema_reset_decay': False, # reset EMA decay to 0 and restart increasing the EMA decay
'use_flash_attn': True, # use flash attention to accelerate training
'use_zero_stage': 2, # use AngelPTM zero stage. choices = [1, 2, 3]
'grad_accu_steps': 2, # gradient accumulation steps
'use_fp16': True, # use FP16 precision
'extra_fp16': False # use extra fp16 for vae and text_encoder
},
'controlnet_config': {
'control_type': 'canny', # Controlnet condition type, choices=['canny', 'depth', 'pose']
'control_weight': 1.0,
# Controlnet weight, You can use a float to specify the weight for all layers, or use a list to separately specify the weight for each layer, for example, '[1.0 * (0.825 ** float(19 - i)) for i in range(19)]
'condition_image_path': None # inference condition image path
},
'deepspeed_config': {
'local_rank': None, # local rank passed from distributed launcher.
'deepspeed_optimizer': True, # switching to the optimizers in DeepSpeed.
'remote_device': 'none', # remote device for ZeRO-3 initialized parameters. choices = ['none', 'cpu', 'nvme'].
'zero_stage': 1, # ZeRO optimization stage.
'async_ema': False # whether to use multi-stream to execute EMA.
}
}
Loading