Skip to content

Commit 75ee9a9

Browse files
committed
mdtv2
1 parent 71a0f94 commit 75ee9a9

File tree

11 files changed

+177
-125
lines changed

11 files changed

+177
-125
lines changed

README.md

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
1-
# Masked Diffusion Transformer
1+
# Masked Diffusion Transformer V2
22

33
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/masked-diffusion-transformer-is-a-strong/image-generation-on-imagenet-256x256)](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?p=masked-diffusion-transformer-is-a-strong)
44
[![HuggingFace space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)](https://huggingface.co/spaces/shgao/MDT)
55

66
The official codebase for [Masked Diffusion Transformer is a Strong Image Synthesizer](https://arxiv.org/abs/2303.14389).
77

8+
## MDTv2: Faster Convergence & Stronger Performance
9+
**MDTv2 demonstrates new SOTA (State of the Art) performance and a 5x acceleration compared to the original MDT.**
10+
811
## Introduction
912

1013
Despite its success in image synthesis, we observe that diffusion probabilistic models (DPMs) often lack contextual reasoning ability to learn the relations among object parts in an image, leading to a slow learning process.
@@ -20,6 +23,7 @@ Experimental results show that MDT achieves superior image synthesis performance
2023
| Model| Dataset | Resolution | FID-50K | Inception Score |
2124
|---------|----------|-----------|---------|--------|
2225
|MDT-XL/2 | ImageNet | 256x256 | 1.79 | 283.01|
26+
|MDTv2-XL/2 | ImageNet | 256x256 | 1.58 | 314.73|
2327

2428
[Pretrained model download](https://huggingface.co/shgao/MDT-XL2/tree/main)
2529

@@ -53,10 +57,10 @@ as the [ADM's dataloder](https://github.com/openai/guided-diffusion) gets the cl
5357
<summary>Training on one node (`run.sh`). </summary>
5458

5559
```shell
56-
export OPENAI_LOGDIR=output_mdt_s2
60+
export OPENAI_LOGDIR=output_mdtv2_s2
5761
NUM_GPUS=8
5862

59-
MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 2 --model MDT_S_2"
63+
MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 4 --model MDTv2_S_2"
6064
DIFFUSION_FLAGS="--diffusion_steps 1000"
6165
TRAIN_FLAGS="--batch_size 32"
6266
DATA_PATH=/dataset/imagenet
@@ -71,8 +75,8 @@ python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_trai
7175

7276
```shell
7377
# On master:
74-
export OPENAI_LOGDIR=output_mdt_xl2
75-
MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 2 --model MDT_XL_2"
78+
export OPENAI_LOGDIR=output_mdtv2_xl2
79+
MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 4 --model MDTv2_XL_2"
7680
DIFFUSION_FLAGS="--diffusion_steps 1000"
7781
TRAIN_FLAGS="--batch_size 4"
7882
DATA_PATH=/dataset/imagenet
@@ -82,8 +86,8 @@ GPU_PRE_NODE=8
8286
python -m torch.distributed.launch --master_addr=$(hostname) --nnodes=$NUM_NODE --node_rank=$RANK --nproc_per_node=$GPU_PRE_NODE --master_port=$MASTER_PORT scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
8387

8488
# On workers:
85-
export OPENAI_LOGDIR=output_mdt_xl2
86-
MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 2 --model MDT_XL_2"
89+
export OPENAI_LOGDIR=output_mdtv2_xl2
90+
MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 4 --model MDTv2_XL_2"
8791
DIFFUSION_FLAGS="--diffusion_steps 1000"
8892
TRAIN_FLAGS="--batch_size 4"
8993
DATA_PATH=/dataset/imagenet
@@ -106,12 +110,12 @@ Please follow the instructions in the `evaluations` folder to set up the evaluat
106110
<summary>Sampling and Evaluation (`run_sample.sh`): </summary>
107111

108112
```shell
109-
MODEL_PATH=output_mdt_xl2/mdt_xl2_v1_ckpt.pt
110-
export OPENAI_LOGDIR=output_mdt_xl2_eval
113+
MODEL_PATH=output_mdtv2_xl2/mdt_xl2_v2_ckpt.pt
114+
export OPENAI_LOGDIR=output_mdtv2_xl2_eval
111115
NUM_GPUS=8
112116

113117
echo 'CFG Class-conditional sampling:'
114-
MODEL_FLAGS="--image_size 256 --model MDT_XL_2 --decode_layer 2"
118+
MODEL_FLAGS="--image_size 256 --model MDTv2_XL_2 --decode_layer 4"
115119
DIFFUSION_FLAGS="--num_sampling_steps 250 --num_samples 50000 --cfg_cond True"
116120
echo $MODEL_FLAGS
117121
echo $DIFFUSION_FLAGS
@@ -123,7 +127,7 @@ echo $MODEL_PATH
123127
python evaluations/evaluator.py ../dataeval/VIRTUAL_imagenet256_labeled.npz $OPENAI_LOGDIR/samples_50000x256x256x3.npz
124128

125129
echo 'Class-conditional sampling:'
126-
MODEL_FLAGS="--image_size 256 --model MDT_XL_2 --decode_layer 2"
130+
MODEL_FLAGS="--image_size 256 --model MDTv2_XL_2 --decode_layer 4"
127131
DIFFUSION_FLAGS="--num_sampling_steps 250 --num_samples 50000"
128132
echo $MODEL_FLAGS
129133
echo $DIFFUSION_FLAGS

infer_mdt.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,23 +8,23 @@
88
from torchvision.utils import save_image
99
from masked_diffusion import create_diffusion
1010
from diffusers.models import AutoencoderKL
11-
from masked_diffusion.models import MDT_XL_2
11+
from masked_diffusion.models import MDTv2_XL_2
1212

1313

1414
# Setup PyTorch:
15-
torch.manual_seed(0)
15+
torch.manual_seed(1)
1616
torch.set_grad_enabled(False)
1717
device = "cuda" if torch.cuda.is_available() else "cpu"
18-
num_sampling_steps = 500
19-
cfg_scale = 5.0
18+
num_sampling_steps = 250
19+
cfg_scale = 4.0
2020
pow_scale = 0.01 # a larger pow_scale increases diversity; a smaller pow_scale increases quality.
21-
model_path = 'mdt_xl2_v1_ckpt.pt'
21+
model_path = 'mdt_xl2_v2_ckpt.pt'
2222

2323
# Load model:
2424
image_size = 256
2525
assert image_size in [256], "We provide pre-trained models for 256x256 resolutions for now."
2626
latent_size = image_size // 8
27-
model = MDT_XL_2(input_size=latent_size, decode_layer=2).to(device)
27+
model = MDTv2_XL_2(input_size=latent_size, decode_layer=4).to(device)
2828

2929
state_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
3030
model.load_state_dict(state_dict)
@@ -33,7 +33,7 @@
3333
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(device)
3434

3535
# Labels to condition the model with:
36-
class_labels = [208]*3
36+
class_labels = [19,23,106,108,278,282]
3737

3838
# Create sampling noise:
3939
n = len(class_labels)

masked_diffusion/gaussian_diffusion.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ class ModelMeanType(enum.Enum):
2828
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
2929
START_X = enum.auto() # the model predicts x_0
3030
EPSILON = enum.auto() # the model predicts epsilon
31+
VELOCITY = enum.auto() # the model predicts v
3132

3233

3334
class ModelVarType(enum.Enum):
@@ -732,6 +733,26 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
732733

733734
terms = {}
734735

736+
737+
mse_loss_weight = None
738+
alpha = _extract_into_tensor(self.sqrt_alphas_cumprod, t, t.shape)
739+
sigma = _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, t.shape)
740+
snr = (alpha / sigma) ** 2
741+
742+
velocity = (alpha[:, None, None, None] * x_t - x_start) / sigma[:, None, None, None]
743+
744+
# get loss weight
745+
if self.model_mean_type is not ModelMeanType.START_X:
746+
mse_loss_weight = th.ones_like(t)
747+
k = 5.0
748+
# min{snr, k}
749+
mse_loss_weight = th.stack([snr, k * th.ones_like(t)], dim=1).min(dim=1)[0] / snr
750+
else:
751+
k = 5.0
752+
# min{snr, k}
753+
mse_loss_weight = th.stack([snr, k * th.ones_like(t)], dim=1).min(dim=1)[0]
754+
755+
735756
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
736757
terms["loss"] = self._vb_terms_bpd(
737758
model=model,
@@ -774,9 +795,10 @@ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
774795
)[0],
775796
ModelMeanType.START_X: x_start,
776797
ModelMeanType.EPSILON: noise,
798+
ModelMeanType.VELOCITY: velocity,
777799
}[self.model_mean_type]
778800
assert model_output.shape == target.shape == x_start.shape
779-
terms["mse"] = mean_flat((target - model_output) ** 2)
801+
terms["mse"] = mse_loss_weight * mean_flat((target - model_output) ** 2)
780802
if "vb" in terms:
781803
terms["loss"] = terms["mse"] + terms["vb"]
782804
else:

0 commit comments

Comments
 (0)