batch_enhancement #51

Merged · 2 commits · Dec 12, 2022
10 changes: 9 additions & 1 deletion pina/label_tensor.py
@@ -79,7 +79,7 @@ def labels(self):

@labels.setter
def labels(self, labels):
        if len(labels) != self.shape[1]:  # small check
            raise ValueError(
                'the tensor does not have the same number of columns as '
                'the passed labels.')
@@ -106,6 +106,14 @@ def to(self, *args, **kwargs):
new.data = tmp.data
return new

def select(self, *args, **kwargs):
"""
Performs Tensor selection. For more details, see :meth:`torch.Tensor.select`.
"""
tmp = super().select(*args, **kwargs)
tmp._labels = self._labels
return tmp

def extract(self, label_to_extract):
"""
Extract the subset of the original tensor by returning all the columns
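The new `select` override above exists because the stock `torch.Tensor.select` returns a tensor without the `labels` attribute; copying `_labels` onto the result keeps the metadata usable downstream. A minimal sketch of the intended behavior (the labels and shapes here are made up for illustration):

    import torch
    from pina.label_tensor import LabelTensor

    pts = LabelTensor(torch.rand(8, 2), ['x', 'y'])

    # torch.Tensor.select(dim, index) slices out one row;
    # the override re-attaches the labels to the result
    row = pts.select(0, 3)
    print(row.labels)   # ['x', 'y'], instead of being lost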
171 changes: 115 additions & 56 deletions pina/pinn.py
@@ -3,30 +3,41 @@

from .problem import AbstractProblem
from .label_tensor import LabelTensor
from .utils import merge_tensors
from .utils import merge_tensors, PinaDataset

torch.pi = torch.acos(torch.zeros(1)).item() * 2  # which is 3.1415927410125732


class PINN(object):

def __init__(self,
problem,
model,
optimizer=torch.optim.Adam,
lr=0.001,
regularizer=0.00001,
dtype=torch.float32,
device='cpu',
error_norm='mse'):
problem,
model,
optimizer=torch.optim.Adam,
lr=0.001,
regularizer=0.00001,
batch_size=None,
dtype=torch.float32,
device='cpu',
error_norm='mse'):
'''
        :param Problem problem: the formulation of the problem.
        :param torch.nn.Module model: the neural network model to use.
        :param torch.optim optimizer: the neural network optimizer to use;
            default is `torch.optim.Adam`.
        :param float lr: the learning rate; default is 0.001.
        :param float regularizer: the coefficient for the L2 regularizer term.
        :param int batch_size: the batch size used by the dataloader; default is
            None, meaning all the sampled points are used in a single batch.
        :param type dtype: the data type to use for the model. Valid options are
            `torch.float32` and `torch.float64` (`torch.float16` only on GPU);
            default is `torch.float32`.
        :param string device: the device used for training; default is 'cpu';
            options include 'cuda' if CUDA is available.
        :param string/int error_norm: the loss function used as minimizer;
            default is the mean square error 'mse'. If string, options are the
            mean error 'me' and the mean square error 'mse'. If int, the p-norm
            is computed, where p is specified by the int input.
'''

if dtype == torch.float64:
@@ -38,7 +49,7 @@ def __init__(self,
# self._architecture['input_dimension'] = self.problem.domain_bound.shape[0]
# self._architecture['output_dimension'] = len(self.problem.variables)
# if hasattr(self.problem, 'params_domain'):
        # self._architecture['input_dimension'] += self.problem.params_domain.shape[0]

self.error_norm = error_norm

@@ -59,6 +70,9 @@ def __init__(self,
self.optimizer = optimizer(
self.model.parameters(), lr=lr, weight_decay=regularizer)

self.batch_size = batch_size
self.data_set = PinaDataset(self)

@property
def problem(self):
return self._problem
@@ -79,7 +93,7 @@ def _compute_norm(self, vec):
        :param torch.Tensor vec: the tensor whose norm is computed
"""
if isinstance(self.error_norm, int):
            return torch.linalg.vector_norm(vec, ord=self.error_norm, dtype=self.dtype)
elif self.error_norm == 'mse':
return torch.mean(vec.pow(2))
elif self.error_norm == 'me':
@@ -90,16 +104,16 @@ def _compute_norm(self, vec):
def save_state(self, filename):

checkpoint = {
            'epoch': self.trained_epoch,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'optimizer_class': self.optimizer.__class__,
            'history': self.history_loss,
            'input_points_dict': self.input_pts,
}

# TODO save also architecture param?
        # if isinstance(self.model, DeepFeedForward):
# checkpoint['model_class'] = self.model.__class__
# checkpoint['model_structure'] = {
# }
@@ -110,7 +124,6 @@ def load_state(self, filename):
checkpoint = torch.load(filename)
self.model.load_state_dict(checkpoint['model_state'])


self.optimizer = checkpoint['optimizer_class'](self.model.parameters())
self.optimizer.load_state_dict(checkpoint['optimizer_state'])

@@ -121,6 +134,39 @@ def load_state(self, filename):

return self

def _create_dataloader(self):
"""Private method for creating dataloader

:return: dataloader
:rtype: torch.utils.data.DataLoader
"""
if self.batch_size is None:
return [self.input_pts]

def custom_collate(batch):
# extracting pts labels
_, pts = list(batch[0].items())[0]
labels = pts.labels
# calling default torch collate
collate_res = default_collate(batch)
# save collate result in dict
res = {}
for key, val in collate_res.items():
val.labels = labels
res[key] = val
return res

        # create one dataset per location
datasets = [MyDataSet(key, val)
for key, val in self.input_pts.items()]
# creating dataloader
dataloaders = [DataLoader(dataset=dat,
batch_size=self.batch_size,
collate_fn=custom_collate)
for dat in datasets]

return dict(zip(self.input_pts.keys(), dataloaders))
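Neither `MyDataSet` nor the `DataLoader`/`default_collate` imports appear in this diff: `DataLoader` and `default_collate` come from `torch.utils.data` (recent PyTorch versions export `default_collate` directly; older ones keep it in `torch.utils.data._utils.collate`), while `MyDataSet` is presumably defined elsewhere in the PR. A minimal sketch of what such a dataset would have to look like for `custom_collate` to work, one dict-shaped sample per sampled point, keyed by location (class layout assumed, not taken from this PR):

    import torch
    from torch.utils.data import Dataset

    class MyDataSet(Dataset):
        """Map-style dataset: one sample per sampled point of a location."""

        def __init__(self, location, tensor):
            self._location = location   # condition name, e.g. 'gamma1'
            self._tensor = tensor       # LabelTensor of shape (n_points, n_vars)

        def __len__(self):
            return self._tensor.shape[0]

        def __getitem__(self, index):
            # dict-shaped samples make default_collate return a dict of
            # stacked tensors, which custom_collate then re-labels
            return {self._location: self._tensor[index]}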

def span_pts(self, *args, **kwargs):
"""
>>> pinn.span_pts(n=10, mode='grid')
@@ -155,59 +201,69 @@ def span_pts(self, *args, **kwargs):
argument['n'],
argument['mode'],
variables=argument['variables'])
            for argument in arguments)
pts = merge_tensors(samples)

# TODO
# pts = pts.double()
pts = pts.to(dtype=self.dtype, device=self.device)
pts.requires_grad_(True)
pts.retain_grad()

self.input_pts[location] = pts

def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):

epoch = 0
data_loader = self.data_set.dataloader

header = []
for condition_name in self.problem.conditions:
condition = self.problem.conditions[condition_name]

            if hasattr(condition, 'function'):
                if isinstance(condition.function, list):
                    for function in condition.function:
                        header.append(f'{condition_name}{function.__name__}')

                    continue

            header.append(f'{condition_name}')

while True:

losses = []

for condition_name in self.problem.conditions:
condition = self.problem.conditions[condition_name]

                for batch in data_loader[condition_name]:

                    single_loss = []

                    if hasattr(condition, 'function'):
                        pts = batch[condition_name]
                        pts = pts.to(dtype=self.dtype, device=self.device)
                        pts.requires_grad_(True)
                        pts.retain_grad()

                        predicted = self.model(pts)
                        for function in condition.function:
                            residuals = function(pts, predicted)
                            local_loss = (
                                condition.data_weight*self._compute_norm(
                                    residuals))
                            single_loss.append(local_loss)
                    elif hasattr(condition, 'output_points'):
                        pts = condition.input_points.to(
                            dtype=self.dtype, device=self.device)
                        predicted = self.model(pts)
                        residuals = predicted - condition.output_points
                        local_loss = (
                            condition.data_weight*self._compute_norm(residuals))
                        single_loss.append(local_loss)

                    self.optimizer.zero_grad()
                    sum(single_loss).backward()
                    self.optimizer.step()

                    losses.append(sum(single_loss))

if save_loss and (epoch % save_loss == 0 or epoch == 0):
self.history_loss[epoch] = [
@@ -221,7 +277,8 @@ def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):

if isinstance(stop, int):
if epoch == stop:
                    print('[epoch {:05d}] {:.6e} '.format(
                        self.trained_epoch, sum(losses).item()), end='')
for loss in losses:
print('{:.6e} '.format(loss.item()), end='')
print()
@@ -236,7 +293,8 @@ def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
print('{:12.12s} '.format(name), end='')
print()

                print('[epoch {:05d}] {:.6e} '.format(
                    self.trained_epoch, sum(losses).item()), end='')
for loss in losses:
print('{:.6e} '.format(loss.item()), end='')
print()
@@ -246,7 +304,6 @@ def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):

return sum(losses).item()


def error(self, dtype='l2', res=100):

import numpy as np
@@ -261,7 +318,8 @@ def error(self, dtype='l2', res=100):
grids_container = self.problem.data_solution['grid']
Z_true = self.problem.data_solution['grid_solution']
try:
            unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(
                dtype=self.dtype, device=self.device)
Z_pred = self.model(unrolled_pts)
Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape)

@@ -273,4 +331,5 @@ def error(self, dtype='l2', res=100):
except:
print("")
print("Something went wrong...")
print("Not able to compute the error. Please pass a data solution or a true solution")
print(
"Not able to compute the error. Please pass a data solution or a true solution")
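Taken together, `batch_size` changes training from one optimizer step over all sampled points to one step per mini-batch, per condition. A hedged usage sketch (the `problem` and `model` objects are assumed to be a PINA problem and a compatible `torch.nn.Module`, defined as in the existing examples):

    from pina.pinn import PINN

    # batch_size=None reproduces the old behavior (all points in one batch);
    # an int enables per-batch optimizer steps via the new dataloader
    pinn = PINN(problem, model, lr=0.001, batch_size=8)

    pinn.span_pts(n=10, mode='grid')   # sample collocation points as before
    final_loss = pinn.train(stop=100, frequency_print=10)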