diff --git a/hydit/config.py b/hydit/config.py index 013d9ce..4cd08d8 100644 --- a/hydit/config.py +++ b/hydit/config.py @@ -4,7 +4,7 @@ from .modules.models import HUNYUAN_DIT_CONFIG, HUNYUAN_DIT_MODELS from .diffusion.gaussian_diffusion import ModelVarType -import deepspeed +# import deepspeed def model_var_type(value): try: @@ -182,15 +182,32 @@ def get_args(default_args=None): # ======================================================================================================== # Deepspeed config # ======================================================================================================== - parser = deepspeed.add_config_arguments(parser) - parser.add_argument('--local_rank', type=int, default=None, - help='local rank passed from distributed launcher.') - parser.add_argument('--deepspeed-optimizer', action='store_true', - help='Switching to the optimizers in DeepSpeed') - parser.add_argument('--remote-device', type=str, default='none', choices=['none', 'cpu', 'nvme'], - help='Remote device for ZeRO-3 initialized parameters.') - parser.add_argument('--zero-stage', type=int, default=1) - parser.add_argument("--async-ema", action="store_true", help="Whether to use multi stream to excut EMA.") + # parser = deepspeed.add_config_arguments(parser) + # parser.add_argument('--local_rank', type=int, default=None, + # help='local rank passed from distributed launcher.') + # parser.add_argument('--deepspeed-optimizer', action='store_true', + # help='Switching to the optimizers in DeepSpeed') + # parser.add_argument('--remote-device', type=str, default='none', choices=['none', 'cpu', 'nvme'], + # help='Remote device for ZeRO-3 initialized parameters.') + # parser.add_argument('--zero-stage', type=int, default=1) + # parser.add_argument("--async-ema", action="store_true", help="Whether to use multi stream to excut EMA.") + + # Attempt to import DeepSpeed and add its arguments if available + try: + import deepspeed + + # Add DeepSpeed-specific arguments + parser = deepspeed.add_config_arguments(parser) + parser.add_argument('--local_rank', type=int, default=None, + help='local rank passed from distributed launcher.') + parser.add_argument('--deepspeed-optimizer', action='store_true', + help='Switching to the optimizers in DeepSpeed') + parser.add_argument('--remote-device', type=str, default='none', choices=['none', 'cpu', 'nvme'], + help='Remote device for ZeRO-3 initialized parameters.') + parser.add_argument('--zero-stage', type=int, default=1) + parser.add_argument("--async-ema", action="store_true", help="Whether to use multi stream to excut EMA.") + except ImportError: + print("DeepSpeed not available. Skipping related arguments...") args = parser.parse_args(default_args) diff --git a/hydit/modules/models.py b/hydit/modules/models.py index 059f9bd..9d025a2 100644 --- a/hydit/modules/models.py +++ b/hydit/modules/models.py @@ -201,7 +201,7 @@ def __init__( self.text_len_t5 = args.text_len_t5 self.norm = args.norm - use_flash_attn = args.infer_mode == 'fa' or args.use_flash_attn + use_flash_attn = args.infer_mode == 'fa' or getattr(args, 'use_flash_attn', False) if use_flash_attn: log_fn(f" Enable Flash Attention.") qk_norm = args.qk_norm # See http://arxiv.org/abs/2302.05442 for details. diff --git a/sample_t2i.py b/sample_t2i.py index f017839..1464e36 100644 --- a/sample_t2i.py +++ b/sample_t2i.py @@ -1,12 +1,9 @@ from pathlib import Path - from loguru import logger - from dialoggen.dialoggen_demo import DialogGen from hydit.config import get_args from hydit.inference import End2End - def inferencer(): args = get_args() models_root_path = Path(args.model_root) @@ -26,6 +23,15 @@ def inferencer(): return args, gen, enhancer +def get_next_index(save_dir): + all_files = list(save_dir.glob('*.png')) + indices = [] + for f in all_files: + try: + indices.append(int(f.stem)) + except ValueError: + logger.warning(f"Skipping file with non-integer name: {f}") + return max(indices, default=-1) + 1 if __name__ == "__main__": args, gen, enhancer = inferencer() @@ -59,12 +65,9 @@ def inferencer(): # Save images save_dir = Path('results') save_dir.mkdir(exist_ok=True) + # Find the first available index - all_files = list(save_dir.glob('*.png')) - if all_files: - start = max([int(f.stem) for f in all_files]) + 1 - else: - start = 0 + start = get_next_index(save_dir) for idx, pil_img in enumerate(images): save_path = save_dir / f"{idx + start}.png"