-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdiffusion_params.py
68 lines (58 loc) · 2.28 KB
/
diffusion_params.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import platform
import torch
from parse import *
def get_model_parameters_for_diffusion_from_string(string):
    """Rebuild a diffusion parameter dict from a pretrained-model name.

    ``string`` must match exactly::

        pretrained/{system}_arch{arch}_e{epochs}_d{iters}_edim{embed}_ks{kernel}_par{n_params}_date{date}

    where ``arch`` is the list of down-sampling dims joined by underscores
    (e.g. ``256_512``). Only the system name, architecture, number of
    diffusion iterations, diffusion-step embedding dim and kernel size are
    taken from the name; the epoch count, parameter count and date fields are
    parsed but ignored. Every other entry comes from
    ``get_model_parameters_for_diffusion()``.

    Args:
        string: The model path/name to parse.

    Returns:
        dict: Default parameters with SYSTEM_NAME, KERNEL_SIZE,
        NUM_DIFFUSION_ITERS, DIFFUSION_STEP_EMBEDDING_DIM, DOWN_DIMS and ARCH
        overridden from ``string``, and SHRINK set to ``None``.

    Raises:
        ValueError: If ``string`` does not match the expected pattern, or a
            numeric field is not an integer.
    """
    result = parse('pretrained/{}_arch{}_e{}_d{}_edim{}_ks{}_par{}_date{}', string)
    if result is None:
        # parse() returns None on mismatch; fail loudly instead of the opaque
        # "cannot unpack non-iterable NoneType" TypeError.
        raise ValueError(f'Unrecognised pretrained model name: {string!r}')

    # Field order: system_name, arch, num_epochs, num_diffusion_iters,
    # diffusion_step_embed_dim, kernel_size, num_param, date_time
    (system_name, arch, _, num_diffusion_iters,
     diffusion_step_embed_dim, kernel_size, _, _) = result.fixed

    down_dims = [int(d) for d in arch.split('_')]

    params = get_model_parameters_for_diffusion()
    params['SYSTEM_NAME'] = system_name
    params['KERNEL_SIZE'] = int(kernel_size)
    params['NUM_DIFFUSION_ITERS'] = int(num_diffusion_iters)
    params['DIFFUSION_STEP_EMBEDDING_DIM'] = int(diffusion_step_embed_dim)
    params['DOWN_DIMS'] = down_dims
    # Keep ARCH in sync with the overridden DOWN_DIMS (the defaults would
    # otherwise leave ARCH describing the default architecture).
    params['ARCH'] = '_'.join(str(d) for d in down_dims)
    # NOTE(review): presumably the dims in the name are already the final
    # (possibly shrunk) sizes, so no further shrinkage is applied — confirm.
    params['SHRINK'] = None
    return params
def get_model_parameters_for_diffusion():
    """Return a fresh dict of default hyper-parameters for the diffusion model.

    A new dict (with new list objects) is built on every call, so callers may
    mutate the result freely.

    Derived entries:
        DEVICE: 'cuda' if available, else 'mps' if available, else 'cpu'.
        IS_M1_ARCH: True exactly when DEVICE is 'mps'.
        BATCH_SIZE: 128 on 'mps', otherwise 256.
        NUM_WORKERS: 2 on 'mps', otherwise 4.
        ARCH: DOWN_DIMS joined with underscores (e.g. [256, 512] -> '256_512').

    When SHRINK > 1, every entry of DOWN_DIMS is integer-divided by SHRINK.

    Returns:
        dict: The parameter dictionary.
    """
    # torch.backends.mps does not exist on older torch builds; guard the
    # lookup so those fall back to CPU instead of raising AttributeError.
    mps_backend = getattr(torch.backends, 'mps', None)
    if torch.cuda.is_available():
        device = 'cuda'
    elif mps_backend is not None and mps_backend.is_available():
        device = 'mps'
    else:
        device = 'cpu'

    params = {
        # Weights and Biases - params
        'PROJECT_NAME': 'diffusion',
        'ENTITY': 'dl-282',
        # ID
        'ID': 0,
        # Hardware
        'DEVICE': device,
        # General
        'SYSTEM_NAME': '2d',
        'SHRINK': 1,
        'LOG_DIR': 'logs',
        'NUM_EPOCHS': 100,
        'DTYPE': torch.float32,
        # Model Architecture
        'KERNEL_SIZE': 5,
        'DOWN_DIMS': [256],
        'N_GROUPS': 2,
        # Diffusion
        'NUM_DIFFUSION_ITERS': 100,
        'DIFFUSION_STEP_EMBEDDING_DIM': 256,
        # Exponential Moving Average
        'EMA_POWER': 0.75,
        # Optimizer
        'OPTIMIZER': 'adamw',
        'LEARNING_RATE': 9e-5,
        'WEIGHT_DECAY': 3e-5,
        'COSINE_LR_NUM_WARMUP_STEPS': 500,
        'WANDB': False,
    }
    params['IS_M1_ARCH'] = params['DEVICE'] == 'mps'
    # Smaller batches / fewer loader workers on the 'mps' device.
    params['BATCH_SIZE'] = 128 if params['IS_M1_ARCH'] else 256
    params['NUM_WORKERS'] = 2 if params['IS_M1_ARCH'] else 4
    # NOTE(review): ARCH is computed before SHRINK is applied below, so it
    # names the un-shrunk dims; confirm this is intended.
    params['ARCH'] = '_'.join(str(d) for d in params['DOWN_DIMS'])
    # Apply shrinkage to the model
    if params['SHRINK'] > 1:
        params['DOWN_DIMS'] = [d // params['SHRINK'] for d in params['DOWN_DIMS']]
    return params