Skip to content

Commit 7673ae4

Browse files
authored
Updated vcae experiment (#45)
* latest experiment * refactor * add SVHN * proper flattening * increase number of epochs * add notebooks * add the option to save results * add the option to save results+1 * cleanup * refactor and simplify
1 parent ae28691 commit 7673ae4

File tree

17 files changed

+1009
-548
lines changed

17 files changed

+1009
-548
lines changed

.gitignore

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,9 @@ MNIST/
1111
logs/
1212
examples/MNIST
1313
examples/logs
14+
examples/checkpoints
1415
examples/**/*.gz
15-
examples/**/*.ipynb
16+
# examples/**/*.ipynb
1617

1718
# Python-generated files
1819
__pycache__/
@@ -239,3 +240,6 @@ Temporary Items
239240

240241
# Built Visual Studio Code Extensions
241242
*.vsix
243+
244+
*.csv
245+
*.mat

README.md

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ Visit the [`./examples/`](https://github.com/TY-Cheng/torchvinecopulib/tree/main
4141

4242
## Installation
4343

44-
- By `pip` from [`PyPI`](https://pypi.org/project/torchvinecopulib/):
44+
- By `pip` from [`PyPI`](https://pypi.org/project/torchvinecopulib/) (see the dependencies and uv sections below for CUDA support):
4545

4646
```bash
4747
pip install torchvinecopulib torch
@@ -62,10 +62,19 @@ pip install ./dist/torchvinecopulib-1.1.0.tar.gz
6262
After `git clone https://github.com/TY-Cheng/torchvinecopulib.git`, `cd` into the project root where [`pyproject.toml`](https://github.com/TY-Cheng/torchvinecopulib/blob/main/pyproject.toml) exists,
6363

6464
```bash
65-
# inside project root folder
66-
uv sync --extra cpu -U
67-
# or
68-
uv sync --extra cu126 -U
65+
# From inside the project root folder
66+
# Create and activate local virtual environment
67+
uv venv .venv
68+
source .venv/bin/activate
69+
70+
# Sync dependencies with CPU support (default)
71+
uv sync --extra cpu
72+
73+
# Or for CUDA support, pick the extra matching your CUDA version: cu126 (CUDA 12.6) or cu128 (CUDA 12.8)
74+
uv sync --extra cu126
75+
76+
# Optionally, install the extra dependencies required by the examples
77+
uv sync --extra examples
6978
```
7079

7180
## Dependencies
@@ -81,6 +90,7 @@ scipy = "*"
8190
torch = [
8291
{ index = "torch-cpu", extra = "cpu" },
8392
{ index = "torch-cu126", extra = "cu126" },
93+
{ index = "torch-cu128", extra = "cu128" },
8494
]
8595
```
8696

examples/vcae_mnist/run_seeds.py renamed to examples/vcae/run_seeds.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,27 @@
22
import logging
33
import os
44
import sys
5+
from typing import Union
56

67
import pandas as pd
78
from tqdm import tqdm
8-
from vcae_mnist.config import config
9-
from vcae_mnist.experiment import run_experiment
9+
from vcae.config import config_mnist, config_svhn
10+
from vcae.experiment import run_experiment
1011

12+
dataset = "MNIST" # or "SVHN"
1113
start = int(sys.argv[1])
1214
end = int(sys.argv[2])
1315

16+
if dataset == "MNIST":
17+
config = config_mnist
18+
elif dataset == "SVHN":
19+
config = config_svhn
20+
else:
21+
raise ValueError(f"Unsupported dataset: {dataset}")
22+
23+
1424
# Redirect tqdm and errors to log file
15-
log_path = f"progress_{start}_{end}.log"
25+
log_path = f"progress_{dataset}_{start}_{end}.log"
1626
log_file = open(log_path, "w")
1727

1828
logging.basicConfig(
@@ -33,14 +43,12 @@ def suppress_output():
3343
logging.getLogger().setLevel(logging_level)
3444

3545

36-
results = []
37-
38-
output_path = f"results_{start}_{end}.csv"
39-
46+
results: list[dict[str, Union[float, int, str]]] = []
47+
output_path = f"results_{dataset}_{start}_{end}.csv"
4048
for seed in tqdm(range(start, end), desc=f"Seeds {start}-{end}", file=log_file):
4149
try:
4250
with suppress_output():
43-
result = run_experiment(seed, config)
51+
result = run_experiment(seed, config, dataset=dataset)
4452
df = pd.DataFrame([result])
4553

4654
# Write headers only once

examples/vcae/run_seeds.sh

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/bin/bash
#
# Launch run_seeds.py over consecutive seed ranges, one background job per range.
#
# Usage: run_seeds.sh [--nohup] [START [END [STEP]]]
#   START    first seed (default 0)
#   END      end of the seed range, exclusive (default 30)
#   STEP     number of seeds handled by each job (default 10);
#            (END - START) must be divisible by STEP
#   --nohup  wrap every job in nohup
#
# The output of each job is written to logs/seeds_<i>_<j>.log, and the script
# blocks until all launched jobs have finished.

# `command -v` is a shell builtin, unlike `which`, and prints nothing when the
# executable is missing, so the empty check below catches a missing python.
PYTHON_BIN=$(command -v python)
if [[ -z "$PYTHON_BIN" ]]; then
    echo "Error: no 'python' executable found on PATH." >&2
    exit 1
fi
USE_NOHUP=false

# Defaults
START=0
END=30
STEP=10

# Parse arguments: pull out flags, keep everything else as positional
POSITIONAL=()
while [[ $# -gt 0 ]]; do
    case "$1" in
        --nohup)
            USE_NOHUP=true
            shift
            ;;
        *)
            POSITIONAL+=("$1")
            shift
            ;;
    esac
done

# Restore positional args
set -- "${POSITIONAL[@]}"

# Assign range values if provided
if [[ $# -ge 1 ]]; then START=$1; fi
if [[ $# -ge 2 ]]; then END=$2; fi
if [[ $# -ge 3 ]]; then STEP=$3; fi

# Validate input: arithmetic contexts silently treat non-numeric words as
# variable names (usually 0), so reject them explicitly first.
for value in "$START" "$END" "$STEP"; do
    if ! [[ "$value" =~ ^-?[0-9]+$ ]]; then
        echo "Error: START, END and STEP must be integers (got '$value')." >&2
        exit 1
    fi
done

if (( STEP <= 0 )); then
    echo "Error: STEP must be a positive integer." >&2
    exit 1
fi

if (( (END - START) % STEP != 0 )); then
    echo "Error: (END - START) must be divisible by STEP." >&2
    exit 1
fi

echo "Using Python binary: $PYTHON_BIN"
echo "Using nohup: $USE_NOHUP"
echo "Range: $START to $END with step $STEP"

# The per-job redirections below fail if the log directory does not exist.
mkdir -p logs || exit 1

# Launch loop: one background job per [i, j) seed range
for ((i = START; i < END; i += STEP)); do
    j=$((i + STEP))
    name="seeds_${i}_${j}"
    echo "Launching $name"
    if $USE_NOHUP; then
        nohup "$PYTHON_BIN" run_seeds.py "$i" "$j" > "logs/$name.log" 2>&1 &
    else
        "$PYTHON_BIN" run_seeds.py "$i" "$j" > "logs/$name.log" 2>&1 &
    fi
done

# Block until every background job has exited
wait
File renamed without changes.

examples/vcae/vcae/config.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import os
2+
from dataclasses import dataclass
3+
4+
import torch
5+
6+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
7+
torch.set_float32_matmul_precision("medium")
8+
9+
10+
@dataclass
class Config:
    """Settings bundle for the autoencoder example experiments.

    All defaults are evaluated once, when the class is defined; in particular
    ``data_dir`` reads the ``PATH_DATASETS`` environment variable at import
    time and ``batch_size`` depends on whether CUDA is available then.
    """

    # --- Reproducibility ---
    seed: int = 42

    # --- Training ---
    data_dir: str = os.getenv("PATH_DATASETS", ".")  # falls back to the current directory
    save_dir: str = "logs/"
    batch_size: int = 512 if torch.cuda.is_available() else 64  # larger batches when CUDA is available
    max_epochs: int = 10
    accelerator: str = DEVICE
    devices: int = 1
    num_workers: int = 1  # or min(15, os.cpu_count())

    # --- Data ---
    dims: tuple[int, ...] = (1, 28, 28)  # presumably (channels, height, width) -- confirm against the model
    val_train_split: float = 0.1

    # --- Model ---
    hidden_size: int = 64
    latent_size: int = 10
    learning_rate: float = 2e-4
    vine_lambda: float = 0.0
    # use_mmd: bool = False
    # mmd_sigmas: list[float] = [1e-1, 1, 10]
    # mmd_lambda: float = 10.0
36+
37+
# Preset used when the dataset is "MNIST": single-channel 28x28 dims, small network.
config_mnist = Config(max_epochs=10, dims=(1, 28, 28), hidden_size=64, latent_size=10)

# Preset used when the dataset is "SVHN": three-channel 32x32 dims, larger network, more epochs.
config_svhn = Config(max_epochs=50, dims=(3, 32, 32), hidden_size=128, latent_size=32)

examples/vcae_mnist/vcae_mnist/experiment.py renamed to examples/vcae/vcae/experiment.py

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
1-
import copy
21
import random
2+
from typing import Union
33

44
import numpy as np
55
import pytorch_lightning as pl
66
import torch
77

88
from .config import DEVICE, Config
99
from .metrics import compute_score
10-
from .model import LitMNISTAutoencoder
10+
from .model import LitAutoencoder, LitMNISTAutoencoder, LitSVHNAutoencoder
1111

1212

1313
def set_seed(seed: int):
@@ -20,11 +20,21 @@ def set_seed(seed: int):
2020
torch.backends.cudnn.deterministic = True
2121

2222

23-
def run_experiment(seed: int, config: Config):
23+
def run_experiment(
24+
seed: int, config: Config, vine_lambda: float = 1.0, dataset: str = "MNIST"
25+
) -> dict[str, Union[float, int, str]]:
26+
# Set the seed for reproducibility
2427
set_seed(seed)
28+
config.seed = seed
2529

2630
# Instantiate the model
27-
model_initial = LitMNISTAutoencoder()
31+
model_initial: LitAutoencoder
32+
if dataset == "MNIST":
33+
model_initial = LitMNISTAutoencoder(config)
34+
elif dataset == "SVHN":
35+
model_initial = LitSVHNAutoencoder(config)
36+
else:
37+
raise ValueError(f"Unsupported dataset: {dataset}")
2838

2939
# Set up trainer
3040
trainer_initial = pl.Trainer(
@@ -46,10 +56,16 @@ def run_experiment(seed: int, config: Config):
4656
model_initial.learn_vine(n_samples=5000)
4757

4858
# Extract test data
49-
rep_initial, _, data_initial, _, samples_initial = model_initial.get_data(stage="test")
59+
rep_initial, _, data_initial, decoded_initial, samples_initial = model_initial.get_data(
60+
stage="test"
61+
)
62+
63+
# Reset the seed for refitting to avoid data leakage
64+
set_seed(seed)
5065

51-
# Deepcopy for refit
52-
model_refit = copy.deepcopy(model_initial)
66+
# Create a new model with the same configuration but reset vine lambda
67+
config.vine_lambda = vine_lambda
68+
model_refit = model_initial.copy_with_config(config)
5369

5470
# Set up trainer for refitting
5571
trainer_refit = pl.Trainer(
@@ -68,21 +84,29 @@ def run_experiment(seed: int, config: Config):
6884
model_refit.to(DEVICE)
6985

7086
# Extract test data
71-
rep_refit, _, data_refit, _, samples_refit = model_refit.get_data(stage="test")
87+
rep_refit, _, data_refit, decoded_refit, samples_refit = model_refit.get_data(stage="test")
7288

89+
assert model_initial.vine is not None
90+
assert model_refit.vine is not None
7391
loglik_initial = model_initial.vine.log_pdf(rep_initial).mean().item()
7492
loglik_refit = model_refit.vine.log_pdf(rep_refit).mean().item()
7593

94+
mse_initial = torch.nn.functional.mse_loss(decoded_initial, data_initial).item()
95+
mse_refit = torch.nn.functional.mse_loss(decoded_refit, data_refit).item()
96+
7697
sigmas = [1e-3, 1e-2, 1e-1, 1, 10, 100]
77-
score_initial = compute_score(data_initial, samples_initial, DEVICE, sigmas=sigmas)
78-
score_refit = compute_score(data_refit, samples_refit, DEVICE, sigmas=sigmas)
98+
score_initial = compute_score(data_initial, samples_initial, sigmas=sigmas)
99+
score_refit = compute_score(data_refit, samples_refit, sigmas=sigmas)
79100

80101
return {
81102
"seed": seed,
82-
"loglik": loglik_initial,
103+
"dataset": dataset,
104+
"mse_initial": mse_initial,
105+
"mse_refit": mse_refit,
106+
"loglik_initial": loglik_initial,
83107
"loglik_refit": loglik_refit,
84-
"mmd": score_initial.mmd,
108+
"mmd_initial": score_initial.mmd,
85109
"mmd_refit": score_refit.mmd,
86-
"fid": score_initial.fid,
110+
"fid_initial": score_initial.fid,
87111
"fid_refit": score_refit.fid,
88112
}

examples/vcae/vcae/metrics.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
import numpy as np
2+
import torch
3+
from scipy import linalg
4+
5+
6+
def mmd(real: torch.Tensor, fake: torch.Tensor, sigmas=(1e-3, 1e-2, 1e-1, 1, 10, 100)):
    """
    Differentiable MMD loss using Gaussian kernels with fixed sigmas and
    distance normalization via Mxx.mean().

    Both inputs are flattened to ``(batch, features)`` before any distance is
    computed, so image-shaped batches are accepted as well.

    Parameters
    ----------
    real : (n, ...) tensor
        Batch of real samples (features or images).
    fake : (m, ...) tensor
        Batch of generated samples.
    sigmas : sequence of float
        Bandwidths for the RBF kernel. Defaults to a wide, fixed set.
        Must not be empty.

    Returns
    -------
    mmd : scalar tensor
        Mean over all bandwidths of ``Kxx.mean() + Kyy.mean() - 2 * Kxy.mean()``.

    Raises
    ------
    ValueError
        If ``sigmas`` is empty.
    """
    if len(sigmas) == 0:
        raise ValueError("sigmas must contain at least one bandwidth.")

    # reshape (unlike view) also handles non-contiguous inputs such as
    # permuted image batches.
    real = real.reshape(real.size(0), -1)
    fake = fake.reshape(fake.size(0), -1)

    def pairwise_squared_distances(x, y):
        x_norm = (x**2).sum(dim=1, keepdim=True)
        y_norm = (y**2).sum(dim=1, keepdim=True)
        # The expanded form can dip slightly below zero through rounding;
        # a squared distance is never negative, so clamp it.
        return (x_norm + y_norm.T - 2.0 * x @ y.T).clamp_min(0.0)

    Mxx = pairwise_squared_distances(real, real)
    Mxy = pairwise_squared_distances(real, fake)
    Myy = pairwise_squared_distances(fake, fake)

    # Normalization factor based on real-real distances. It is detached so it
    # acts as a constant, and bounded away from zero so that degenerate
    # batches (one sample, or all samples identical) do not produce 0/0 = NaN.
    scale = Mxx.mean().detach().clamp_min(torch.finfo(Mxx.dtype).eps)

    mmd_total = 0.0
    for sigma in sigmas:
        denom = scale * 2.0 * sigma**2
        Kxx = torch.exp(-Mxx / denom)
        Kxy = torch.exp(-Mxy / denom)
        Kyy = torch.exp(-Myy / denom)

        mmd_total += Kxx.mean() + Kyy.mean() - 2.0 * Kxy.mean()

    return mmd_total / len(sigmas)
50+
51+
52+
def fid(X, Y):
    """
    Frechet-style distance between the Gaussians fitted to two sample sets.

    Computes ``sqrt(|m_X - m_Y|^2 + trace(C_X + C_Y - 2 * sqrtm(C_X @ C_Y)))``
    where ``m`` and ``C`` are the per-feature mean and covariance of each set.

    Parameters
    ----------
    X : (n, ...) tensor
        First batch of samples; trailing dimensions are flattened.
    Y : (m, ...) tensor
        Second batch of samples; trailing dimensions are flattened.

    Returns
    -------
    score : scalar tensor
        0-dim CPU tensor with the dtype of ``X``.
    """
    # Flatten to (samples, features) as mmd() does, and detach / move to CPU so
    # grad-tracking or CUDA inputs can be converted to NumPy. The statistics
    # are computed in float64 for numerical stability.
    X_np = X.detach().reshape(X.size(0), -1).cpu().numpy().astype(np.float64)
    Y_np = Y.detach().reshape(Y.size(0), -1).cpu().numpy().astype(np.float64)

    m = X_np.mean(axis=0)
    m_w = Y_np.mean(axis=0)

    # np.cov returns a 0-d array for a single feature; sqrtm needs a matrix.
    C = np.atleast_2d(np.cov(X_np, rowvar=False))
    C_w = np.atleast_2d(np.cov(Y_np, rowvar=False))
    # sqrtm may return a complex matrix with negligible imaginary parts.
    C_C_w_sqrt = linalg.sqrtm(C.dot(C_w)).real

    diff = m - m_w
    score = diff.dot(diff) + np.trace(C + C_w - 2 * C_C_w_sqrt)
    # Rounding can push a near-zero score slightly negative, which would turn
    # into NaN under the square root.
    return torch.tensor(np.sqrt(max(float(score), 0.0)), dtype=X.dtype)
64+
65+
66+
class Score:
    """Plain holder for the ``mmd`` and ``fid`` metric values.

    Both attributes start at 0 on the class; ``compute_score`` overwrites
    them on the instance it returns.
    """

    mmd = fid = 0
69+
70+
71+
def compute_score(real, fake, sigmas=(1e-3, 1e-2, 1e-1, 1, 10, 100)):
    """
    Evaluate generated samples against real ones with MMD and FID.

    Parameters
    ----------
    real : tensor
        Batch of real samples.
    fake : tensor
        Batch of generated samples.
    sigmas : sequence of float
        Kernel bandwidths forwarded to ``mmd``.

    Returns
    -------
    s : Score
        ``s.mmd`` holds the square root of the value returned by ``mmd`` and
        ``s.fid`` the value returned by ``fid``, both as NumPy values.
    """
    # Detach so tensors that still track gradients can be converted to NumPy.
    real = real.detach().to("cpu")
    fake = fake.detach().to("cpu")

    s = Score()
    # For near-identical sample sets the estimate can round to a tiny negative
    # number; clip at zero so the square root does not yield NaN.
    mmd_value = mmd(real, fake, sigmas).numpy()
    s.mmd = np.sqrt(np.maximum(mmd_value, 0.0))
    s.fid = fid(fake, real).numpy()

    return s

0 commit comments

Comments
 (0)