diff --git a/pina/label_tensor.py b/pina/label_tensor.py
index 53449300e..1f0d5be85 100644
--- a/pina/label_tensor.py
+++ b/pina/label_tensor.py
@@ -79,7 +79,7 @@ def labels(self):
 
     @labels.setter
     def labels(self, labels):
-        if len(labels) != self.shape[1]: # small check
+        if len(labels) != self.shape[1]:  # small check
             raise ValueError(
                 'the tensor has not the same number of columns of '
                 'the passed labels.')
@@ -106,6 +106,14 @@ def to(self, *args, **kwargs):
         new.data = tmp.data
         return new
 
+    def select(self, *args, **kwargs):
+        """
+        Performs Tensor selection. For more details, see :meth:`torch.Tensor.select`.
+        """
+        tmp = super().select(*args, **kwargs)
+        tmp._labels = self._labels
+        return tmp
+
     def extract(self, label_to_extract):
         """
         Extract the subset of the original tensor by returning all the columns
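
The `select` override is needed because `torch.Tensor.select` returns a slice that loses the custom `_labels` attribute; `SampleDataset.__getitem__` (added in `pina/utils.py` below) relies on it to hand out single samples that still carry their labels. A minimal sketch of the new behavior, using a small hand-built `LabelTensor` (the shape and labels here are illustrative, not from the patch):

```python
import torch
from pina.label_tensor import LabelTensor

# Hypothetical 4x2 tensor of (x, y) points, e.g. sampled collocation points.
pts = LabelTensor(torch.rand(4, 2), labels=['x', 'y'])

# select(0, i) returns row i; with the override the result still carries
# the column labels (copied via _labels, which bypasses the shape check
# in the labels setter, since the selected row is 1-D).
sample = pts.select(0, 2)
assert sample.labels == ['x', 'y']
```
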
diff --git a/pina/pinn.py b/pina/pinn.py
index 1ca26360b..e990da4d0 100644
--- a/pina/pinn.py
+++ b/pina/pinn.py
@@ -3,30 +3,41 @@
 
 from .problem import AbstractProblem
 from .label_tensor import LabelTensor
-from .utils import merge_tensors
+from .utils import merge_tensors, PinaDataset
 
-torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
+
+torch.pi = torch.acos(torch.zeros(1)).item() * 2  # which is 3.1415927410125732
 
 
 class PINN(object):
 
     def __init__(self,
-                 problem,
-                 model,
-                 optimizer=torch.optim.Adam,
-                 lr=0.001,
-                 regularizer=0.00001,
-                 dtype=torch.float32,
-                 device='cpu',
-                 error_norm='mse'):
+                 problem,
+                 model,
+                 optimizer=torch.optim.Adam,
+                 lr=0.001,
+                 regularizer=0.00001,
+                 batch_size=None,
+                 dtype=torch.float32,
+                 device='cpu',
+                 error_norm='mse'):
         '''
         :param Problem problem: the formualation of the problem.
         :param torch.nn.Module model: the neural network model to use.
+        :param torch.optim optimizer: the neural network optimizer to use;
+            default is `torch.optim.Adam`.
         :param float lr: the learning rate; default is 0.001.
        :param float regularizer: the coefficient for L2 regularizer term.
        :param type dtype: the data type to use for the model. Valid option are
            `torch.float32` and `torch.float64` (`torch.float16` only on GPU);
            default is `torch.float64`.
+        :param string device: the device used for training; default is 'cpu';
+            options include 'cuda' if CUDA is available.
+        :param string/int error_norm: the loss function used as minimizer;
+            default is the mean square error 'mse'. The string options are the
+            mean error 'me' and the mean square error 'mse'; if an int is
+            passed, the p-norm with p equal to the given int is computed.
+        :param int batch_size: the batch size for the dataloader; default is None, i.e. all the sampled points in a single batch.
         '''
 
         if dtype == torch.float64:
@@ -38,7 +49,7 @@ def __init__(self,
         # self._architecture['input_dimension'] = self.problem.domain_bound.shape[0]
         # self._architecture['output_dimension'] = len(self.problem.variables)
         # if hasattr(self.problem, 'params_domain'):
-        #     self._architecture['input_dimension'] += self.problem.params_domain.shape[0]
+        #     self._architecture['input_dimension'] += self.problem.params_domain.shape[0]
 
         self.error_norm = error_norm
 
@@ -59,6 +70,9 @@ def __init__(self,
         self.optimizer = optimizer(
             self.model.parameters(), lr=lr, weight_decay=regularizer)
 
+        self.batch_size = batch_size
+        self.data_set = PinaDataset(self)
+
     @property
     def problem(self):
         return self._problem
@@ -79,7 +93,7 @@ def _compute_norm(self, vec):
         :param vec torch.tensor: the tensor
         """
         if isinstance(self.error_norm, int):
-            return torch.linalg.vector_norm(vec, ord = self.error_norm, dtype=self.dytpe)
+            return torch.linalg.vector_norm(vec, ord=self.error_norm, dtype=self.dtype)
         elif self.error_norm == 'mse':
             return torch.mean(vec.pow(2))
         elif self.error_norm == 'me':
@@ -90,16 +104,16 @@ def _compute_norm(self, vec):
 
     def save_state(self, filename):
         checkpoint = {
-                'epoch': self.trained_epoch,
-                'model_state': self.model.state_dict(),
-                'optimizer_state' : self.optimizer.state_dict(),
-                'optimizer_class' : self.optimizer.__class__,
-                'history' : self.history_loss,
-                'input_points_dict' : self.input_pts,
+            'epoch': self.trained_epoch,
+            'model_state': self.model.state_dict(),
+            'optimizer_state': self.optimizer.state_dict(),
+            'optimizer_class': self.optimizer.__class__,
+            'history': self.history_loss,
+            'input_points_dict': self.input_pts,
         }
 
         # TODO save also architecture param?
-        #if isinstance(self.model, DeepFeedForward):
+        # if isinstance(self.model, DeepFeedForward):
         #    checkpoint['model_class'] = self.model.__class__
         #    checkpoint['model_structure'] = {
         #    }
@@ -110,7 +124,6 @@ def load_state(self, filename):
 
         checkpoint = torch.load(filename)
         self.model.load_state_dict(checkpoint['model_state'])
 
-        self.optimizer = checkpoint['optimizer_class'](self.model.parameters())
         self.optimizer.load_state_dict(checkpoint['optimizer_state'])
 
@@ -121,6 +134,18 @@ def load_state(self, filename):
 
         return self
 
+    def _create_dataloader(self):
+        """Private method for creating the dataloader.
+
+        The construction is delegated to :class:`PinaDataset`, which returns
+        all the sampled points in a single batch when ``self.batch_size``
+        is None.
+
+        :return: dataloader
+        :rtype: torch.utils.data.DataLoader
+        """
+        return self.data_set.dataloader
+
     def span_pts(self, *args, **kwargs):
         """
         >>> pinn.span_pts(n=10, mode='grid')
@@ -155,59 +180,69 @@ def span_pts(self, *args, **kwargs):
                 argument['n'],
                 argument['mode'],
                 variables=argument['variables'])
-                for argument in arguments)
+            for argument in arguments)
         pts = merge_tensors(samples)
 
         # TODO
         # pts = pts.double()
-        pts = pts.to(dtype=self.dtype, device=self.device)
-        pts.requires_grad_(True)
-        pts.retain_grad()
-
         self.input_pts[location] = pts
 
     def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
 
         epoch = 0
+        data_loader = self.data_set.dataloader
 
         header = []
         for condition_name in self.problem.conditions:
             condition = self.problem.conditions[condition_name]
 
-            if (hasattr(condition, 'function') and
-                    isinstance(condition.function, list)):
-                for function in condition.function:
-                    header.append(f'{condition_name}{function.__name__}')
-            else:
-                header.append(f'{condition_name}')
+            if hasattr(condition, 'function'):
+                if isinstance(condition.function, list):
+                    for function in condition.function:
+                        header.append(f'{condition_name}{function.__name__}')
+
+                    continue
+
+            header.append(f'{condition_name}')
 
         while True:
+            losses = []
 
             for condition_name in self.problem.conditions:
                 condition = self.problem.conditions[condition_name]
 
-                if hasattr(condition, 'function'):
-                    pts = self.input_pts[condition_name]
-                    predicted = self.model(pts)
-                    for function in condition.function:
-                        residuals = function(pts, predicted)
+                for batch in data_loader[condition_name]:
+
+                    single_loss = []
+
+                    if hasattr(condition, 'function'):
+                        pts = batch[condition_name]
+                        pts = pts.to(dtype=self.dtype, device=self.device)
+                        pts.requires_grad_(True)
+                        pts.retain_grad()
+
+                        predicted = self.model(pts)
+                        for function in condition.function:
+                            residuals = function(pts, predicted)
+                            local_loss = (
+                                condition.data_weight*self._compute_norm(
+                                    residuals))
+                            single_loss.append(local_loss)
+                    elif hasattr(condition, 'output_points'):
+                        pts = condition.input_points.to(
+                            dtype=self.dtype, device=self.device)
+                        predicted = self.model(pts)
+                        residuals = predicted - condition.output_points
                         local_loss = (
-                            condition.data_weight*self._compute_norm(
-                                residuals))
-                        losses.append(local_loss)
-                elif hasattr(condition, 'output_points'):
-                    pts = condition.input_points
-                    predicted = self.model(pts)
-                    residuals = predicted - condition.output_points
-                    local_loss = (
-                        condition.data_weight*self._compute_norm(residuals))
-                    losses.append(local_loss)
-
-            self.optimizer.zero_grad()
-
-            sum(losses).backward()
-            self.optimizer.step()
+                            condition.data_weight*self._compute_norm(residuals))
+                        single_loss.append(local_loss)
+
+                    self.optimizer.zero_grad()
+                    sum(single_loss).backward()
+                    self.optimizer.step()
+
+                    losses.append(sum(single_loss))
 
             if save_loss and (epoch % save_loss == 0 or epoch == 0):
                 self.history_loss[epoch] = [
@@ -221,7 +256,8 @@ def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
 
             if isinstance(stop, int):
                 if epoch == stop:
-                    print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='')
+                    print('[epoch {:05d}] {:.6e} '.format(
+                        self.trained_epoch, sum(losses).item()), end='')
                     for loss in losses:
                         print('{:.6e} '.format(loss.item()), end='')
                     print()
@@ -236,7 +272,8 @@ def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
                     print('{:12.12s} '.format(name), end='')
                 print()
 
-            print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='')
+            print('[epoch {:05d}] {:.6e} '.format(
+                self.trained_epoch, sum(losses).item()), end='')
             for loss in losses:
                 print('{:.6e} '.format(loss.item()), end='')
             print()
@@ -246,7 +283,6 @@ def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
 
         return sum(losses).item()
 
-
     def error(self, dtype='l2', res=100):
 
         import numpy as np
@@ -261,7 +297,8 @@ def error(self, dtype='l2', res=100):
             grids_container = self.problem.data_solution['grid']
             Z_true = self.problem.data_solution['grid_solution']
         try:
-            unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(dtype=self.dtype, device=self.device)
+            unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(
+                dtype=self.dtype, device=self.device)
             Z_pred = self.model(unrolled_pts)
             Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape)
@@ -273,4 +310,5 @@ def error(self, dtype='l2', res=100):
         except:
             print("")
             print("Something went wrong...")
-            print("Not able to compute the error. Please pass a data solution or a true solution")
+            print(
+                "Not able to compute the error. Please pass a data solution or a true solution")
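
The rewritten `train` above changes the optimization scheme: instead of one optimizer step on the sum of all condition losses per epoch, it now performs one step per batch of each condition, and `losses` only collects the per-batch sums for logging and for the returned value. Below is a self-contained sketch of that per-batch update pattern using plain `torch` and a toy model; it is not the PINA API, just the control flow:

```python
import torch
from torch.utils.data import DataLoader

model = torch.nn.Linear(2, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

pts = torch.rand(10, 2)              # stand-in for one condition's sampled points
loader = DataLoader(pts, batch_size=6)

losses = []                          # per-batch losses, kept only for logging
for batch in loader:
    batch.requires_grad_(True)       # residuals need gradients w.r.t. the inputs
    predicted = model(batch)
    single_loss = predicted.pow(2).mean()   # toy residual norm
    optimizer.zero_grad()
    single_loss.backward()           # one optimizer step per batch
    optimizer.step()
    losses.append(single_loss.detach())

print(sum(losses).item())            # the epoch loss that train() prints and returns
```
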
diff --git a/pina/utils.py b/pina/utils.py
index ed18af084..9d380ac7d 100644
--- a/pina/utils.py
+++ b/pina/utils.py
@@ -1,8 +1,10 @@
 """Utils module"""
 from functools import reduce
 
+import torch
+from torch.utils.data import DataLoader, default_collate
 
-def number_parameters(model, aggregate=True, only_trainable=True): #TODO: check
+def number_parameters(model, aggregate=True, only_trainable=True):  # TODO: check
     """
     Return the number of parameters of a given `model`.
 
@@ -41,5 +43,67 @@ def merge_two_tensors(tensor1, tensor2):
 
     tensor1 = LabelTensor(tensor1.repeat(n2, 1), labels=tensor1.labels)
     tensor2 = LabelTensor(tensor2.repeat_interleave(n1, dim=0),
-            labels=tensor2.labels)
+                          labels=tensor2.labels)
     return tensor1.append(tensor2)
+
+
+class PinaDataset():
+
+    def __init__(self, pinn) -> None:
+        self.pinn = pinn
+
+    @property
+    def dataloader(self):
+        return self._create_dataloader()
+
+    @property
+    def dataset(self):
+        return [self.SampleDataset(key, val)
+                for key, val in self.pinn.input_pts.items()]
+
+    def _create_dataloader(self):
+        """Private method for creating the dataloader.
+
+        :return: dataloader
+        :rtype: torch.utils.data.DataLoader
+        """
+        if self.pinn.batch_size is None:
+            return {key: [{key: val}] for key, val in self.pinn.input_pts.items()}
+
+        def custom_collate(batch):
+            # extract the labels from the first sample of the batch
+            _, pts = list(batch[0].items())[0]
+            labels = pts.labels
+            # call the default torch collate, which loses the labels
+            collate_res = default_collate(batch)
+            # re-attach the labels to the collated tensors
+            res = {}
+            for key, val in collate_res.items():
+                val.labels = labels
+                res[key] = val
+            return res
+
+        # create one dataset per location
+        datasets = [self.SampleDataset(key, val)
+                    for key, val in self.pinn.input_pts.items()]
+        # create one dataloader per location
+        dataloaders = [DataLoader(dataset=dat,
+                                  batch_size=self.pinn.batch_size,
+                                  collate_fn=custom_collate)
+                       for dat in datasets]
+
+        return dict(zip(self.pinn.input_pts.keys(), dataloaders))
+
+    class SampleDataset(torch.utils.data.Dataset):
+
+        def __init__(self, location, tensor):
+            self._tensor = tensor
+            self._location = location
+            self._len = len(tensor)
+
+        def __getitem__(self, index):
+            tensor = self._tensor.select(0, index)
+            return {self._location: tensor}
+
+        def __len__(self):
+            return self._len
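
`PinaDataset` groups the sampled points by location: with an integer `batch_size`, each location gets a `SampleDataset` wrapped in a `DataLoader` whose `custom_collate` restores the labels that `default_collate` drops while stacking; with `batch_size=None`, each location maps to a one-element list that mimics a dataloader yielding a single full batch. A sketch of the resulting structure, using a minimal stand-in for the `pinn` argument (the `SimpleNamespace` object is purely illustrative; the class only reads `input_pts` and `batch_size` from it):

```python
import torch
from types import SimpleNamespace
from pina.label_tensor import LabelTensor
from pina.utils import PinaDataset

# Hypothetical single location 'D' with 10 sampled (x, y) points.
fake_pinn = SimpleNamespace(
    input_pts={'D': LabelTensor(torch.rand(10, 2), labels=['x', 'y'])},
    batch_size=4,
)

loader = PinaDataset(fake_pinn).dataloader   # {'D': DataLoader}
for batch in loader['D']:
    pts = batch['D']                 # LabelTensor with up to 4 rows
    print(pts.shape, pts.labels)     # the labels survive the collate
```
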
diff --git a/tests/test_pinn.py b/tests/test_pinn.py
index f4b5ce890..a054d9f5e 100644
--- a/tests/test_pinn.py
+++ b/tests/test_pinn.py
@@ -31,19 +31,22 @@ def nil_dirichlet(input_, output_):
 
     def poisson_sol(self, pts):
         return -(
-            torch.sin(pts.extract(['x'])*torch.pi)*
+            torch.sin(pts.extract(['x'])*torch.pi) *
             torch.sin(pts.extract(['y'])*torch.pi)
         )/(2*torch.pi**2)
 
     truth_solution = poisson_sol
 
+
 problem = Poisson()
 model = FeedForward(problem.input_variables,
                     problem.output_variables)
 
+
 def test_constructor():
     PINN(problem, model)
 
+
 def test_span_pts():
     pinn = PINN(problem, model)
     n = 10
@@ -60,6 +63,7 @@ def test_span_pts():
     pinn.span_pts(n, 'random', locations=['D'])
     assert pinn.input_pts['D'].shape[0] == n
 
+
 def test_train():
     pinn = PINN(problem, model)
     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
@@ -68,6 +72,7 @@ def test_train():
     pinn.span_pts(n, 'grid', locations=['D'])
     pinn.train(5)
 
+
 def test_train():
     boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
     n = 10
@@ -78,4 +83,45 @@ def test_train():
         pinn.span_pts(n, 'grid', boundaries)
         pinn.span_pts(n, 'grid', locations=['D'])
         pinn.train(50, save_loss=i)
-        assert list(pinn.history_loss.keys()) == truth_key
\ No newline at end of file
+        assert list(pinn.history_loss.keys()) == truth_key
+
+
+def test_train_batch():
+    pinn = PINN(problem, model, batch_size=6)
+    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+    n = 10
+    pinn.span_pts(n, 'grid', boundaries)
+    pinn.span_pts(n, 'grid', locations=['D'])
+    pinn.train(5)
+
+
+def test_train_batch_save_loss():
+    boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+    n = 10
+    expected_keys = [[], list(range(0, 50, 3))]
+    param = [0, 3]
+    for i, truth_key in zip(param, expected_keys):
+        pinn = PINN(problem, model, batch_size=6)
+        pinn.span_pts(n, 'grid', boundaries)
+        pinn.span_pts(n, 'grid', locations=['D'])
+        pinn.train(50, save_loss=i)
+        assert list(pinn.history_loss.keys()) == truth_key
+
+
+if torch.cuda.is_available():
+
+    def test_gpu_train():
+        pinn = PINN(problem, model, batch_size=20, device='cuda')
+        boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+        n = 100
+        pinn.span_pts(n, 'grid', boundaries)
+        pinn.span_pts(n, 'grid', locations=['D'])
+        pinn.train(5)
+
+    def test_gpu_train_nobatch():
+        pinn = PINN(problem, model, batch_size=None, device='cuda')
+        boundaries = ['gamma1', 'gamma2', 'gamma3', 'gamma4']
+        n = 100
+        pinn.span_pts(n, 'grid', boundaries)
+        pinn.span_pts(n, 'grid', locations=['D'])
+        pinn.train(5)
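
For reference, the end-to-end flow that the batch tests above exercise, written out as a plain script. The import paths are assumptions based on how the test module uses `PINN` and `FeedForward`, and `Poisson` stands for the problem class defined at the top of this test file:

```python
from pina import PINN
from pina.model import FeedForward

# Poisson is the toy problem defined in tests/test_pinn.py above.
problem = Poisson()
model = FeedForward(problem.input_variables, problem.output_variables)

# batch_size=None keeps the previous behavior (a single full batch).
pinn = PINN(problem, model, batch_size=6)
pinn.span_pts(10, 'grid', ['gamma1', 'gamma2', 'gamma3', 'gamma4'])
pinn.span_pts(10, 'grid', locations=['D'])
pinn.train(5)
```
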