Allow any model architecture #3

Status: Open. Wants to merge 22 commits into base: main.
213 changes: 213 additions & 0 deletions examples/locoprop_torch.py
@@ -0,0 +1,213 @@
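# Example script for this PR: trains MNIST architectures built from LocoLayer / LocoLinear
# blocks, either with LocoProp local updates or with a plain LaProp baseline
# (toggled via lctr.IS_BASELINE at the bottom of the file).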
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.backends import cudnn, cuda

from pytorch_lightning.callbacks import TQDMProgressBar
from pytorch_lightning.loggers import TensorBoardLogger
from locoprop.layer import LocoLayer, LocoLinear
from locoprop.trainer import LocopropTrainer
import locoprop.trainer as lctr
import pytorch_lightning

cudnn.benchmark = True
cudnn.deterministic = False
cudnn.allow_tf32 = True
cuda.matmul.allow_tf32 = True
cuda.matmul.allow_fp16_reduced_precision_reduction = False
cuda.matmul.allow_bf16_reduced_precision_reduction = False
torch.use_deterministic_algorithms(False)
torch.set_float32_matmul_precision("medium")


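# LaProp-style optimizer: the gradient is normalized by the second-moment estimate *before*
# it enters the momentum buffer (unlike Adam). This variant additionally applies decoupled
# weight decay, adaptive gradient clipping, and a sign update whose magnitude is grafted
# from the momentum norm.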
class LaProp(torch.optim.Optimizer):
def __init__(self, params, lr: float, betas: tuple = (0.9, 0.99), weight_decay: float = 1e-3,
gradient_clip_val: float = 1e-2, eps: float = 1e-8):
self.eps = eps # should depend on float accuracy, not parameter group
super().__init__(params, {"weight_decay": weight_decay, "lr": lr, "betas": betas,
"gradient_clip_val": gradient_clip_val})

def zero_grad(self, set_to_none: bool = True):
return # already set to none below

@torch.no_grad()
def step(self, closure=None):
if closure is None:
loss = None
else:
with torch.enable_grad():
loss = closure()

for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue

                # take the gradient and clear it on the parameter (zero_grad above is a deliberate no-op)
                grad = p.grad
                p.grad = None
                p64 = p  # alias of the parameter, not an fp64 copy; the decay below must apply in place

# weight decay
if p.ndim > 1:
p64 *= 1 - group["weight_decay"] * group["lr"]

# adaptive gradient clipping
                g_norm = grad.norm().clamp(min=self.eps)  # eps guards against division by zero
p_norm = p64.norm().clamp(min=1e-3) # "parameter eps" to allow training of 0 params
scale = p_norm / g_norm * group["gradient_clip_val"]
grad *= scale.clamp(max=1)

# init state
beta1, beta2 = group["betas"]
state = self.state[p]
if len(state) == 0:
state["step"] = torch.tensor(0, dtype=torch.int64, device=p.device)
state["exp_avg"] = torch.zeros_like(p, dtype=torch.float32)
state["exp_avg_sq"] = torch.zeros_like(p, dtype=torch.float32)

# init variables
step = state["step"].add_(1)
exp_avg = state["exp_avg"]
exp_avg_sq = state["exp_avg_sq"]

# LaProp transform on gradient
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
exp_avg_sq_corrected = exp_avg_sq / (1 - beta2 ** step) # bias correction
grad /= torch.nan_to_num_(exp_avg_sq_corrected.sqrt(), 0, 0, 0).clamp(min=self.eps)
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

# Graft to self
scale = -group["lr"]
scale /= p.numel() ** 0.5 # sign norm for grafting
scale *= 1 - beta1 ** step # bias correction
scale *= exp_avg.norm() # ema norm for grafting

p.data += torch.where(exp_avg > 0, scale, -scale)
return loss


class Flatten(nn.Module):
def forward(self, x):
return x.view(x.size(0), -1)


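# Fully connected MNIST autoencoder (784 -> 1000 -> 500 -> 250 -> 30 -> 250 -> 500 -> 1000 -> 784)
# built from LocoLayer-wrapped nn.Linear modules; the final sigmoid is implicit, so the model
# returns logits consistent with compute_loss below.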
class LocoNet(nn.Sequential):
def __init__(self, input_dim=784, output_dim=784, hidden_dims=[1000, 500, 250, 30, 250, 500, 1000],
activation_cls=nn.Tanh, ):
super().__init__()
self.add_module("flatten", Flatten())
self.add_module("input", LocoLayer(nn.Linear(input_dim, hidden_dims[0]), activation_cls()))
for i, (d_in, d_out) in enumerate(zip(hidden_dims[:-1], hidden_dims[1:])):
self.add_module(f"stage{i}", LocoLayer(nn.Linear(d_in, d_out), activation_cls()))
self.add_module("output",
LocoLayer(nn.Linear(hidden_dims[-1], output_dim, bias=False), nn.Sigmoid(), implicit=True))


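# Same autoencoder topology as LocoNet, but built from the combined LocoLinear layer instead
# of wrapping nn.Linear in LocoLayer.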
class LocoLinearNet(nn.Sequential):
def __init__(self, input_dim=784, output_dim=784, hidden_dims=[1000, 500, 250, 30, 250, 500, 1000],
activation_cls=nn.Tanh, ):
super().__init__()
self.add_module("flatten", Flatten())
self.add_module("input", LocoLinear(input_dim, hidden_dims[0], activation_cls()))
for i, (d_in, d_out) in enumerate(zip(hidden_dims[:-1], hidden_dims[1:])):
self.add_module(f"stage{i}", LocoLinear(d_in, d_out, activation_cls()))
self.add_module("output", LocoLinear(hidden_dims[-1], output_dim, nn.Sigmoid(), implicit=True))


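# Small MNIST classifier: two Conv2d + activation + MaxPool2d stages (28x28 -> 14x14 -> 7x7),
# followed by an implicit-softmax linear head over `classes` outputs.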
class LocoConvNet(nn.Sequential):
def __init__(self, classes=10, kernel_size=5, stride=1, padding=2, activation_cls=nn.Tanh, ):
super().__init__()
self.add_module("stage_0", LocoLayer(nn.Conv2d(1, 16, kernel_size, stride, padding), activation_cls()))
self.add_module("pool_0", nn.MaxPool2d(2))
self.add_module("stage_1", LocoLayer(nn.Conv2d(16, 32, kernel_size, stride, padding), activation_cls()))
self.add_module("pool_1", nn.MaxPool2d(2))
self.add_module("flatten", Flatten())
self.add_module("output",
LocoLayer(nn.Linear(32 * 7 * 7, classes, bias=False), nn.Softmax(dim=-1), implicit=True))


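# Builds the whole dataset as one tensor instead of using a DataLoader: pad by `move`, take
# every (2 * move + 1)^2 shifted 28x28 crop as translation augmentation, shuffle, drop the
# remainder, and reshape to (num_batches, batch_size, 1, 28, 28).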
def transform_data(train: bool, batch_size: int, move: int = 0):
with torch.no_grad():
# convert data to torch.FloatTensor
transform = torchvision.transforms.Compose(
[torchvision.transforms.Pad(move), torchvision.transforms.ToTensor()])

# load the training and test datasets
data = torchvision.datasets.MNIST(root='~/.pytorch/MNIST_data/', train=train, download=True,
transform=transform)
images, _labels = zip(*data)
data = torch.stack(images)
# data = torch.cat([data, 1 - data])
data = torch.cat([data[:, :, i:28 + i, j:28 + j] for i in range(move * 2 + 1) for j in range(move * 2 + 1)])
data = data[torch.randperm(data.shape[0])]
data = data[:data.shape[0] // batch_size * batch_size]
data = data.view(-1, batch_size, *data.shape[1:])
        data = data.pin_memory()  # pin_memory() returns a pinned copy, so reassign it
return data


def get_dataloaders(batch_size=128):
return transform_data(True, batch_size, 2), transform_data(False, batch_size, 0)


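# LightningModule wrapper around LocopropTrainer. The batches produced above contain images
# only, so the MLP variants are trained to reconstruct their (flattened) input.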
class Model(pytorch_lightning.LightningModule):
def __init__(self, module: LocopropTrainer):
super().__init__()
self.module = module
        if isinstance(module.model, (LocoNet, LocoLinearNet)):  # both MLP variants reconstruct their input
self.loss_fn = compute_loss
else:
self.loss_fn = F.cross_entropy

def forward(self, x):
return self.module(x)

def training_step(self, x, step):
loss = self.loss_fn(self(x), x.view(x.size(0), -1))
self.log("train/loss", loss)
return loss

def validation_step(self, x, step):
loss = self.loss_fn(self(x), x.view(x.size(0), -1))
self.log("val/loss", loss)
return loss

def configure_optimizers(self):
if lctr.IS_BASELINE:
return self.module.inner_opt_class(self.parameters(), **self.module.inner_opt_hparams)
return torch.optim.SGD(self.parameters(), lr=1)


def compute_loss(logits, y):
return F.binary_cross_entropy_with_logits(logits, y) * logits.size(-1)


# cls: one of the architectures above (defaults to LocoLinearNet)
def run(cls: type = LocoLinearNet, epochs: int = 100, batch_size: int = 16384):
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.manual_seed(1)

print("Device:", device)

train_dl, val_dl = get_dataloaders(batch_size)

def _train():
model = Model(LocopropTrainer(cls(), inner_opt_class=LaProp, learning_rate=1, correction=0, iterations=100,
inner_opt_hparams=dict(lr=1e-3, betas=(0.9, 0.9)))).to(device)

        trainer = pytorch_lightning.Trainer(accelerator="gpu" if torch.cuda.is_available() else "cpu", devices="auto",
callbacks=[TQDMProgressBar()], log_every_n_steps=8,
enable_checkpointing=False, max_epochs=epochs,
logger=TensorBoardLogger("logs", name=str(lctr.IS_BASELINE)))
trainer.fit(model, train_dl, val_dl)

return _train


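# Toggle lctr.IS_BASELINE to compare a plain LaProp baseline (True) against LocoProp local
# updates (False); the TensorBoard run is named str(IS_BASELINE), so both runs can be
# compared side by side.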
train_fn = run()

# lctr.IS_BASELINE = True; train_fn()
lctr.IS_BASELINE = False
train_fn()
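# Hypothetical variation (not in the original script): the same harness accepts any of the
# architectures defined above, e.g.
#   train_fn = run(cls=LocoNet, epochs=10, batch_size=4096)
#   lctr.IS_BASELINE = True
#   train_fn()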