# Masked Diffusion Transformer V2
[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/masked-diffusion-transformer-is-a-strong/image-generation-on-imagenet-256x256)](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?p=masked-diffusion-transformer-is-a-strong)
[![HuggingFace space](https://img.shields.io/badge/🤗-HuggingFace%20Space-cyan.svg)](https://huggingface.co/spaces/shgao/MDT)
The official codebase for [Masked Diffusion Transformer is a Strong Image Synthesizer](https://arxiv.org/abs/2303.14389).
## MDTv2: Faster Convergence & Stronger Performance
**MDTv2 achieves new SOTA (state-of-the-art) performance and a 5x training acceleration compared to the original MDT.**
## Introduction
Despite their success in image synthesis, we observe that diffusion probabilistic models (DPMs) often lack the contextual reasoning ability to learn the relations among object parts in an image, leading to a slow learning process.
Experimental results show that MDT achieves superior image synthesis performance.
| Model      | Dataset  | Resolution | FID-50K | Inception Score |
|------------|----------|------------|---------|-----------------|
| MDT-XL/2   | ImageNet | 256x256    | 1.79    | 283.01          |
| MDTv2-XL/2 | ImageNet | 256x256    | 1.58    | 314.73          |
[Pretrained model download](https://huggingface.co/shgao/MDT-XL2/tree/main)
The [ADM dataloader](https://github.com/openai/guided-diffusion) gets the class labels from the folder names.
<summary>Training on one node (`run.sh`).</summary>
```shell
export OPENAI_LOGDIR=output_mdtv2_s2
NUM_GPUS=8

MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 4 --model MDTv2_S_2"
DIFFUSION_FLAGS="--diffusion_steps 1000"
TRAIN_FLAGS="--batch_size 32"
DATA_PATH=/dataset/imagenet

python -m torch.distributed.launch --nproc_per_node=$NUM_GPUS scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS
```
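With these flags, each launched process consumes `--batch_size` samples per step, so the effective global batch size is NUM_GPUS times the per-GPU batch size. A minimal sanity-check sketch, assuming the 8 GPUs and `--batch_size 32` used above:

```shell
# Sanity-check the effective global batch size for the single-node run.
# Values mirror the single-node example: 8 GPUs, per-GPU batch size 32.
NUM_GPUS=8
BATCH_PER_GPU=32
GLOBAL_BATCH=$((NUM_GPUS * BATCH_PER_GPU))
echo "global batch size: $GLOBAL_BATCH"   # prints 256
```

Adjust `--batch_size` if you change the GPU count and want to keep the same global batch size.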
```shell
# On master:
export OPENAI_LOGDIR=output_mdtv2_xl2
MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 2 --model MDTv2_XL_2"
DIFFUSION_FLAGS="--diffusion_steps 1000"
TRAIN_FLAGS="--batch_size 4"
DATA_PATH=/dataset/imagenet
GPU_PRE_NODE=8
python -m torch.distributed.launch --master_addr=$(hostname) --nnodes=$NUM_NODE --node_rank=$RANK --nproc_per_node=$GPU_PRE_NODE --master_port=$MASTER_PORT scripts/image_train.py --data_dir $DATA_PATH $MODEL_FLAGS $DIFFUSION_FLAGS $TRAIN_FLAGS

# On workers:
export OPENAI_LOGDIR=output_mdtv2_xl2
MODEL_FLAGS="--image_size 256 --mask_ratio 0.30 --decode_layer 2 --model MDTv2_XL_2"
DIFFUSION_FLAGS="--diffusion_steps 1000"
TRAIN_FLAGS="--batch_size 4"
DATA_PATH=/dataset/imagenet
```
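For multi-node runs, `torch.distributed` derives the world size from `--nnodes` times `--nproc_per_node`. A minimal sketch of that arithmetic (`NUM_NODE=2` is an assumed example value, not taken from the scripts above; `GPU_PRE_NODE=8` matches them):

```shell
# World size for a multi-node job: nodes times GPUs per node.
# NUM_NODE=2 is an assumed example value; GPU_PRE_NODE=8 matches the script.
NUM_NODE=2
GPU_PRE_NODE=8
WORLD_SIZE=$((NUM_NODE * GPU_PRE_NODE))
echo "world size: $WORLD_SIZE"   # prints 16
```

Every node must see the same `--master_addr`, `--master_port`, and `--nnodes`, while `--node_rank` differs per node.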
Please follow the instructions in the `evaluations` folder to set up the evaluation environment.
<summary>Sampling and Evaluation (`run_sample.sh`):</summary>
```shell
MODEL_PATH=output_mdtv2_xl2/mdt_xl2_v2_ckpt.pt
export OPENAI_LOGDIR=output_mdtv2_xl2_eval
NUM_GPUS=8

echo 'CFG Class-conditional sampling:'
MODEL_FLAGS="--image_size 256 --model MDTv2_XL_2 --decode_layer 4"
DIFFUSION_FLAGS="--num_sampling_steps 250 --num_samples 50000 --cfg_cond True"
echo $MODEL_FLAGS
echo $DIFFUSION_FLAGS
echo $MODEL_PATH
python evaluations/evaluator.py ../dataeval/VIRTUAL_imagenet256_labeled.npz $OPENAI_LOGDIR/samples_50000x256x256x3.npz

echo 'Class-conditional sampling:'
MODEL_FLAGS="--image_size 256 --model MDTv2_XL_2 --decode_layer 4"
DIFFUSION_FLAGS="--num_sampling_steps 250 --num_samples 50000"
echo $MODEL_FLAGS
echo $DIFFUSION_FLAGS
```