ddpm #915
base: develop
@@ -0,0 +1,64 @@

```yaml
mode: train
train:
  seed: 1234
  output_dir: 'experiments/km256/'
  doc: 'weights/km256/'
  timesteps: 1000
  log_path: ''
  verbose: 'info'
  ni: false
  comment: ''
  resume_training: false
data:
  dataset: "kolmogorov flow"
  data_dir: "/home/aistudio/data/data264003/train_data.npy"
```
Review comment: Change these to relative paths under the ddpm directory; the same applies to the other paths in the yaml files.
```yaml
  stat_path: "/home/aistudio/data/data264003/km256_stats.npz"
  image_size: 256
  channels: 3
  logit_transform: false
  uniform_dequantization: false
  gaussian_dequantization: false
  random_flip: false
  rescaled: false
  num_workers: 0

model:
  type: "conditional"
  in_channels: 3
  out_ch: 3
  ch: 64
  ch_mult: [1, 1, 1, 2]
  num_res_blocks: 1
  attn_resolutions: [16, ]
  dropout: 0.1
  var_type: fixedlarge
  ema_rate: 0.9999
  ema: True
  resamp_with_conv: True
  ckpt_path: '/home/aistudio/data/data264003/init_ckpt.pdparams'

diffusion:
  beta_schedule: linear
  beta_start: 0.0001
  beta_end: 0.02
  num_diffusion_timesteps: 1000  # might need to be changed to 500 later to match SDEdit

training:
  batch_size: 1
  n_epochs: 20  # 300 epochs take about 12 hours
  n_iters: 200000
  snapshot_freq: 20000
  validation_freq: 2000

sampling:
  batch_size: 32
  last_only: True

optim:
  weight_decay: 0.000
  optimizer: "Adam"
  lr: 0.0002
  beta1: 0.9
  amsgrad: false
  eps: 0.00000001
  grad_clip: 1.0
```
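The `diffusion` section fixes a linear beta schedule. As a minimal sketch of how such a schedule and its cumulative products are typically built from these values (the helper name `make_beta_schedule` is illustrative, not part of this PR):

```python
import numpy as np

def make_beta_schedule(beta_start=0.0001, beta_end=0.02, num_steps=1000):
    # linear schedule: beta_t interpolates from beta_start to beta_end
    betas = np.linspace(beta_start, beta_end, num_steps, dtype=np.float64)
    # alpha_bar_t = prod_{s<=t} (1 - beta_s), the quantity compute_alpha returns below
    alpha_bars = np.cumprod(1.0 - betas)
    return betas, alpha_bars

betas, alpha_bars = make_beta_schedule()
print(betas.shape, alpha_bars[-1])  # (1000,) and a value close to 0
```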
@@ -0,0 +1,50 @@

```yaml
log_dir: "./experiments/kmflow_re1000_rs256_ddim_recons_conditional_log"
```
Review comment: Same as the previous file.
```yaml
mode: eval
place: 'cuda'
data:
  dataset: "kmflow"
  category: "kmflow"
  image_size: 256
  channels: 3
  num_workers: 0
  data_dir: "/home/aistudio/data/data264003/kf_2d_re1000_256_40seed.npy"
  sample_data_dir: "/home/aistudio/data/data264003/kmflow_sampled_data_irregnew.npz"
  stat_path: "/home/aistudio/data/data264003/km256_stats.npz"
  data_kw: 'u3232'
  smoothing: True
  smoothing_scale: 7

model:
  type: "conditional"
  in_channels: 3
  out_ch: 3
  ch: 64
  ch_mult: [1, 1, 1, 2]
  num_res_blocks: 1
  attn_resolutions: [16, ]
  dropout: 0.0
  var_type: fixedlarge
  ema_rate: 0.9999
  ema: True
  resamp_with_conv: True
  ckpt_path: "/home/aistudio/data/data264003/latest.pdparams"

diffusion:
  beta_schedule: linear
  beta_start: 0.0001
  beta_end: 0.02
  num_diffusion_timesteps: 1000

eval:
  repeat_run: 1
  sample_step: 1
  t: 240
  reverse_steps: 30
  seed: 1234

sampling:
  batch_size: 20
  last_only: True
  guidance_weight: 0.
  log_loss: True
  dump_arr: True
```
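These configs are plain YAML. A minimal sketch of loading one into nested attribute access, under the assumption that the code uses a `dict2namespace`-style helper common in DDPM codebases (the helper and the file name are illustrative, not shown in this diff):

```python
import argparse
import yaml

def dict2namespace(config):
    # recursively convert a dict into argparse.Namespace for cfg.model.ch-style access
    ns = argparse.Namespace()
    for key, value in config.items():
        setattr(ns, key, dict2namespace(value) if isinstance(value, dict) else value)
    return ns

with open("kmflow_re1000_rs256.yml") as f:  # hypothetical file name
    cfg = dict2namespace(yaml.safe_load(f))
print(cfg.model.ch_mult, cfg.diffusion.num_diffusion_timesteps)
```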
@@ -0,0 +1,201 @@

```python
import paddle
```

Review comment: Please remove the comments inside the functions (except parts that genuinely need explanation), add a short hint at the top of complex functions, write comments in English with half-width characters, and run pre-commit to check the code format after the changes.
```python
def compute_alpha(beta, t):
    # prepend a zero so that index t + 1 maps to timestep t (and t = -1 gives alpha_bar = 1)
    beta = paddle.concat(x=[paddle.zeros(shape=[1]).to("gpu"), beta], axis=0)
    beta_sub = 1 - beta
    # cumulative product of (1 - beta) gives alpha_bar
    cumprod_beta = paddle.cumprod(beta_sub, dim=0)
    # paddle.gather replaces torch's index_select; convert t to a tensor if it is a scalar
    if not isinstance(t, paddle.Tensor):
        t = paddle.to_tensor(t, dtype="int64")
    selected = paddle.gather(cumprod_beta, index=t + 1, axis=0)
    # reshape to (-1, 1, 1, 1) for broadcasting (torch's .view)
    a = paddle.reshape(selected, shape=[-1, 1, 1, 1])
    return a
```
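For reference, `compute_alpha` evaluates the standard cumulative-product coefficient from the DDPM literature (a notation note, not part of the diff). With the schedule entries indexed from 1, the prepended zero shifts the index by one, so the `t = -1` entries of `seq_next` return exactly 1:

$$
\texttt{compute\_alpha}(\beta, t) \;=\; \prod_{s=1}^{t+1} (1 - \beta_s),
\qquad
\texttt{compute\_alpha}(\beta, -1) \;=\; 1 .
$$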
```python
def ddim_steps(x, seq, model, b, **kwargs):
    n = x.shape[0]
    seq_next = [-1] + list(seq[:-1])
    x0_preds = []
    xs = [x]
    dx_func = kwargs.get("dx_func", None)
    clamp_func = kwargs.get("clamp_func", None)
    cache = kwargs.get("cache", False)
    logger = kwargs.get("logger", None)
```
Review comment: Please revise this to follow the paddlescience format:
```python
    if logger is not None:
        logger.update(x=xs[-1])
    for i, j in zip(reversed(seq), reversed(seq_next)):
        with paddle.no_grad():
            t = (paddle.ones(shape=n) * i).to(x.place)
            next_t = (paddle.ones(shape=n) * j).to(x.place)
            at = compute_alpha(b, t.astype(dtype="int64"))
            at_next = compute_alpha(b, next_t.astype(dtype="int64"))
            xt = xs[-1].to("cuda")
            et = model(xt, t)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            x0_preds.append(x0_t.to("cpu"))
            c2 = (1 - at_next).sqrt()
        if dx_func is not None:
            dx = dx_func(xt)
        else:
            dx = 0
        with paddle.no_grad():
            xt_next = at_next.sqrt() * x0_t + c2 * et - dx
            if clamp_func is not None:
                xt_next = clamp_func(xt_next)
            xs.append(xt_next.to("cpu"))
        if logger is not None:
            logger.update(x=xs[-1])
    if not cache:
        xs = xs[-1:]
        x0_preds = x0_preds[-1:]
    return xs, x0_preds
```
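A minimal sketch of driving `ddim_steps`, assuming a GPU build of Paddle (these helpers pin tensors to CUDA) and the linear beta schedule above; the dummy model stands in for the trained U-Net and is illustrative only:

```python
import paddle

T = 1000
betas = paddle.linspace(0.0001, 0.02, T)

class DummyModel(paddle.nn.Layer):
    # stand-in for the trained noise-prediction U-Net
    def forward(self, x, t, dx=None):
        return paddle.zeros_like(x)

model = DummyModel()
x = paddle.randn(shape=[2, 3, 256, 256])  # start from pure noise
seq = list(range(0, T, T // 30))          # ~30 reverse steps, as in the eval config
xs, x0_preds = ddim_steps(x, seq, model, betas)
print(xs[-1].shape)                       # [2, 3, 256, 256]
```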
```python
def ddpm_steps(x, seq, model, b, **kwargs):
```

Review comment: Same as the function above; this applies to the other functions as well.
```python
    n = x.shape[0]
    seq_next = [-1] + list(seq[:-1])
    xs = [x]
    x0_preds = []
    betas = b
    dx_func = kwargs.get("dx_func", None)
    cache = kwargs.get("cache", False)
    clamp_func = kwargs.get("clamp_func", None)
    for i, j in zip(reversed(seq), reversed(seq_next)):
        with paddle.no_grad():
            t = (paddle.ones(shape=n) * i).to(x.place)
            next_t = (paddle.ones(shape=n) * j).to(x.place)
            at = compute_alpha(betas, t.astype(dtype="int64"))
            atm1 = compute_alpha(betas, next_t.astype(dtype="int64"))
            beta_t = 1 - at / atm1
            x = xs[-1].to("cuda")
            output = model(x, t.astype(dtype="float32"))
            e = output
            x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e
            x0_from_e = paddle.clip(x=x0_from_e, min=-1, max=1)
            x0_preds.append(x0_from_e.to("cpu"))
            mean_eps = (
                atm1.sqrt() * beta_t * x0_from_e + (1 - beta_t).sqrt() * (1 - atm1) * x
            ) / (1.0 - at)
            mean = mean_eps
            noise = paddle.randn(shape=x.shape, dtype=x.dtype)
            mask = 1 - (t == 0).astype(dtype="float32")
            # equivalent of torch's mask.view(-1, 1, 1, 1)
            mask = paddle.reshape(mask, shape=[-1, 1, 1, 1])
            logvar = beta_t.log()
        if dx_func is not None:
            dx = dx_func(x)
        else:
            dx = 0
        with paddle.no_grad():
            sample = mean + mask * paddle.exp(x=0.5 * logvar) * noise - dx
            if clamp_func is not None:
                sample = clamp_func(sample)
            xs.append(sample.to("cpu"))
    if not cache:
        xs = xs[-1:]
        x0_preds = x0_preds[-1:]
    return xs, x0_preds
```
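The `mean_eps` expression is the standard DDPM posterior mean of q(x_{t-1} | x_t, x_0). With beta_t computed as 1 - alpha_bar_t / alpha_bar_{t-1} as above, the factor sqrt(1 - beta_t) equals sqrt(alpha_t), so the code matches the textbook form (a derivation note, not part of the diff):

$$
\mu_t(x_t, x_0) \;=\; \frac{\sqrt{\bar\alpha_{t-1}}\,\beta_t}{1-\bar\alpha_t}\, x_0
\;+\; \frac{\sqrt{\alpha_t}\,\bigl(1-\bar\alpha_{t-1}\bigr)}{1-\bar\alpha_t}\, x_t .
$$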
```python
def guided_ddpm_steps(x, seq, model, b, **kwargs):
    n = x.shape[0]
    seq_next = [-1] + list(seq[:-1])
    xs = [x]
    x0_preds = []
    betas = b
    dx_func = kwargs.get("dx_func", None)
    if dx_func is None:
        raise ValueError("dx_func is required for guided denoising")
    clamp_func = kwargs.get("clamp_func", None)
    cache = kwargs.get("cache", False)
    w = kwargs.get("w", 3.0)
    for i, j in zip(reversed(seq), reversed(seq_next)):
        with paddle.no_grad():
            t = (paddle.ones(shape=n) * i).to(x.place)
            next_t = (paddle.ones(shape=n) * j).to(x.place)
            at = compute_alpha(betas, t.astype(dtype="int64"))
            atm1 = compute_alpha(betas, next_t.astype(dtype="int64"))
            beta_t = 1 - at / atm1
            x = xs[-1].to("cuda")
        dx = dx_func(x)
```
Review comment: Every later use of dx happens inside `with paddle.no_grad()`; why is this call placed outside it?
```python
        with paddle.no_grad():
            output = (w + 1) * model(x, t.astype(dtype="float32"), dx) - w * model(
                x, t.astype(dtype="float32")
            )
            e = output
            x0_from_e = (1.0 / at).sqrt() * x - (1.0 / at - 1).sqrt() * e
            x0_from_e = paddle.clip(x=x0_from_e, min=-1, max=1)
            x0_preds.append(x0_from_e.to("cpu"))
            mean_eps = (
                atm1.sqrt() * beta_t * x0_from_e + (1 - beta_t).sqrt() * (1 - atm1) * x
            ) / (1.0 - at)
            mean = mean_eps
            noise = paddle.randn(shape=x.shape, dtype=x.dtype)
            mask = 1 - (t == 0).astype(dtype="float32")
            # equivalent of torch's mask.view(-1, 1, 1, 1)
            mask = paddle.reshape(mask, shape=[-1, 1, 1, 1])
            logvar = beta_t.log()
        with paddle.no_grad():
```
Review comment: Why isn't this block merged with the previous `with paddle.no_grad()`?
```python
            sample = mean + mask * paddle.exp(x=0.5 * logvar) * noise - dx
            if clamp_func is not None:
                sample = clamp_func(sample)
            xs.append(sample.to("cpu"))
    if not cache:
        xs = xs[-1:]
        x0_preds = x0_preds[-1:]
    return xs, x0_preds
```
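Both guided variants blend a conditional and an unconditional noise prediction with guidance weight w, in the classifier-free-guidance style, and additionally subtract dx from the final sample as a data-consistency correction (a notation note, not part of the diff):

$$
\hat\epsilon \;=\; (w+1)\,\epsilon_\theta(x_t, t, dx) \;-\; w\,\epsilon_\theta(x_t, t) .
$$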
```python
def guided_ddim_steps(x, seq, model, b, **kwargs):
    n = x.shape[0]
    seq_next = [-1] + list(seq[:-1])
    x0_preds = []
    xs = [x]
    dx_func = kwargs.get("dx_func", None)
    if dx_func is None:
        raise ValueError("dx_func is required for guided denoising")
    clamp_func = kwargs.get("clamp_func", None)
    cache = kwargs.get("cache", False)
    w = kwargs.get("w", 3.0)
    logger = kwargs.get("logger", None)
    if logger is not None:
        # move the tensor to the CUDA device
        xs[-1] = paddle.to_tensor(xs[-1], place=paddle.CUDAPlace(0))
        logger.update(x=xs[-1])
    for i, j in zip(reversed(seq), reversed(seq_next)):
        with paddle.no_grad():
            t = (paddle.ones(shape=n) * i).to("gpu")
            next_t = (paddle.ones(shape=n) * j).to("gpu")
            at = compute_alpha(b, t.astype(dtype="int64"))
            at_next = compute_alpha(b, next_t.astype(dtype="int64"))
            xt = xs[-1].to("gpu")
        dx = dx_func(xt)
        with paddle.no_grad():
            et = (w + 1) * model(xt, t, dx) - w * model(xt, t)
            x0_t = (xt - et * (1 - at).sqrt()) / at.sqrt()
            x0_preds.append(x0_t.to("cpu"))
            c2 = (1 - at_next).sqrt()
        with paddle.no_grad():
            xt_next = at_next.sqrt() * x0_t + c2 * et - dx
            if clamp_func is not None:
                xt_next = clamp_func(xt_next)
            xs.append(xt_next.to("cpu"))
        if logger is not None:
            logger.update(x=xs[-1])
    if not cache:
        xs = xs[-1:]
        x0_preds = x0_preds[-1:]
    return xs, x0_preds
```
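The guided variants expect `dx_func` to supply a data-consistency gradient. A hedged sketch of such a callable, assuming a simple quadratic penalty toward an observation `y_obs` (the penalty, factory name, and scale are assumptions, not part of this PR):

```python
def make_dx_func(y_obs, scale=1.0):
    # hypothetical data-consistency term: gradient of 0.5 * ||x - y_obs||^2
    def dx_func(x):
        x = x.detach()
        x.stop_gradient = False  # enable autograd on the sample
        loss = 0.5 * paddle.sum((x - y_obs) ** 2)
        (grad,) = paddle.grad(outputs=loss, inputs=x)
        return scale * grad
    return dx_func

# usage sketch: xs, _ = guided_ddim_steps(x, seq, model, betas,
#                                         dx_func=make_dx_func(y_obs), w=0.0)
```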
Review comment: Please follow the other examples (e.g. aneurysm) for the config file format: add a hydra section, group the remaining parameters into MODEL/TRAIN/EVAL and others, and trim the parameters that are not used.
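A rough sketch of the layout the reviewer is asking for, modeled on PaddleScience examples such as aneurysm; the hydra block, section names, and retained keys are assumptions about the target style, not part of this PR:

```yaml
hydra:
  run:
    dir: outputs_ddpm/${now:%Y-%m-%d}/${now:%H-%M-%S}
  job:
    chdir: false

mode: train
seed: 1234
output_dir: ${hydra:run.dir}

MODEL:
  in_channels: 3
  out_ch: 3
  ch: 64
  ch_mult: [1, 1, 1, 2]
  num_res_blocks: 1
  attn_resolutions: [16]
  dropout: 0.1

TRAIN:
  batch_size: 1
  epochs: 20
  lr: 0.0002

EVAL:
  batch_size: 20
  reverse_steps: 30
```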