diff --git a/src/deepymod/model/library.py b/src/deepymod/model/library.py
index 5d76bb841..d452c6e1c 100644
--- a/src/deepymod/model/library.py
+++ b/src/deepymod/model/library.py
@@ -1,7 +1,7 @@
 """ Contains the library classes that store the parameters u_t, theta"""
 import numpy as np
 import torch
-from torch.autograd import grad
+from torch import autograd
 from itertools import combinations
 from functools import reduce
 from .deepmod import Library
@@ -20,17 +20,16 @@ def library_poly(prediction: torch.Tensor, max_order: int) -> torch.Tensor:
     Returns:
         torch.Tensor: Tensor with polynomials (n_samples, max_order + 1)
     """
-    u = torch.ones_like(prediction)
-    for order in np.arange(1, max_order + 1):
-        u = torch.cat((u, u[:, order - 1 : order] * prediction), dim=1)
+    polynomials = [prediction ** order for order in torch.arange(1, max_order + 1)]
+    u = torch.cat([torch.ones_like(prediction)] + polynomials, dim=1)
     return u
 
 
-def library_deriv(
-    data: torch.Tensor, prediction: torch.Tensor, max_order: int
+def derivs(
+    prediction: torch.Tensor, coordinates: torch.Tensor, max_order: int
 ) -> Tuple[torch.Tensor, torch.Tensor]:
-    """Given a prediction u evaluated at data (t, x), returns du/dt and du/dx up to max_order, including ones
+    """Given a prediction u evaluated at coordinates (t, x), returns du/dt and du/dx up to max_order, including ones
     as first column.
 
     Args:
@@ -41,31 +40,21 @@ def library_deriv(
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: time derivative and feature library ((n_samples, 1), (n_samples, max_order + 1))
     """
-    dy = grad(
-        prediction, data, grad_outputs=torch.ones_like(prediction), create_graph=True
+    assert max_order > 0, "Only 1st order and up allowed."
+
+    grad = lambda f: autograd.grad(
+        f, coordinates, grad_outputs=torch.ones_like(f), create_graph=True
     )[0]
-    time_deriv = dy[:, 0:1]
-
-    if max_order == 0:
-        du = torch.ones_like(time_deriv)
-    else:
-        du = torch.cat((torch.ones_like(time_deriv), dy[:, 1:2]), dim=1)
-        if max_order > 1:
-            for order in np.arange(1, max_order):
-                du = torch.cat(
-                    (
-                        du,
-                        grad(
-                            du[:, order : order + 1],
-                            data,
-                            grad_outputs=torch.ones_like(prediction),
-                            create_graph=True,
-                        )[0][:, 1:2],
-                    ),
-                    dim=1,
-                )
-
-    return time_deriv, du
+
+    df = grad(prediction)
+    time_derivs, dx = df[:, [0]], df[:, [1]]
+
+    du = [torch.ones_like(prediction), dx]
+    for order in np.arange(1, max_order):
+        du.append(grad(du[order])[:, [1]])
+    space_derivs = torch.cat(du, dim=1)
+
+    return time_derivs, space_derivs
 
 
 # ========================= Actual library functions ========================
@@ -103,50 +92,64 @@ def library(
             Tuple[TensorList, TensorList]: The time derivatives [(n_samples, 1) x n_outputs] and the thetas [(n_samples, (poly_order + 1)(deriv_order + 1))]
             computed from the library and data.
""" - prediction, data = input - poly_list = [] - deriv_list = [] - time_deriv_list = [] + prediction, coordinates = input + time_derivs, space_derivs = self.derivative_features( + prediction, coordinates, self.diff_order + ) + thetas = self.build_features(prediction, space_derivs, self.poly_order) - # Creating lists for all outputs - for output in np.arange(prediction.shape[1]): - time_deriv, du = library_deriv( - data, prediction[:, output : output + 1], self.diff_order - ) - u = library_poly(prediction[:, output : output + 1], self.poly_order) + return time_derivs, thetas - poly_list.append(u) - deriv_list.append(du) - time_deriv_list.append(time_deriv) + @staticmethod + def derivative_features( + prediction: torch.Tensor, coordinates: torch.Tensor, diff_order: int + ) -> Tuple[TensorList, TensorList]: - samples = time_deriv_list[0].shape[0] - total_terms = poly_list[0].shape[1] * deriv_list[0].shape[1] + # Calculate derivs over all outputs + n_outputs = prediction.shape[1] + df = [ + derivs(prediction[:, [output]], coordinates, diff_order) + for output in np.arange(n_outputs) + ] + # Unzip to separate time and space + time_derivs, space_derivs = map(list, zip(*df)) + return time_derivs, space_derivs + + @staticmethod + def build_features( + prediction: torch.Tensor, space_derivs: torch.Tensor, poly_order: int + ) -> TensorList: + + n_samples, n_outputs = prediction.shape + total_terms = (poly_order + 1) * space_derivs[0].shape[1] + + # Creating lists for all outputs + poly_list = [ + library_poly(prediction[:, [output]], poly_order) + for output in np.arange(n_outputs) + ] # Calculating theta - if len(poly_list) == 1: + if n_outputs == 1: # If we have a single output, we simply calculate and flatten matrix product # between polynomials and derivatives to get library theta = torch.matmul( - poly_list[0][:, :, None], deriv_list[0][:, None, :] - ).view(samples, total_terms) + poly_list[0][:, :, None], space_derivs[0][:, None, :] + ).view(n_samples, total_terms) + 
         else:
-            theta_uv = reduce(
-                (lambda x, y: (x[:, :, None] @ y[:, None, :]).view(samples, -1)),
+            uv = reduce(
+                (lambda x, y: (x[:, :, None] @ y[:, None, :]).view(n_samples, -1)),
                 poly_list,
             )
             # calculate all unique combinations of derivatives
-            theta_dudv = torch.cat(
-                [
-                    torch.matmul(du[:, :, None], dv[:, None, :]).view(samples, -1)[
-                        :, 1:
-                    ]
-                    for du, dv in combinations(deriv_list, 2)
-                ],
-                1,
-            )
-            theta = torch.cat([theta_uv, theta_dudv], dim=1)
-
-        return time_deriv_list, [theta]
+            dudv = [
+                torch.matmul(du[:, :, None], dv[:, None, :]).view(n_samples, -1)[:, 1:]
+                for du, dv in combinations(space_derivs, 2)
+            ]
+            dudv = torch.cat(dudv, dim=1)
+            theta = torch.cat([uv, dudv], dim=1)
+        return [theta]
 
 
 class Library2D(Library):
@@ -180,7 +183,7 @@ def library(
             u = torch.cat((u, u[:, order - 1 : order] * prediction), dim=1)
 
         # Gradients
-        du = grad(
+        du = autograd.grad(
             prediction,
             data,
             grad_outputs=torch.ones_like(prediction),
@@ -189,12 +192,12 @@
         u_t = du[:, 0:1]
         u_x = du[:, 1:2]
         u_y = du[:, 2:3]
-        du2 = grad(
+        du2 = autograd.grad(
             u_x, data, grad_outputs=torch.ones_like(prediction), create_graph=True
         )[0]
         u_xx = du2[:, 1:2]
         u_xy = du2[:, 2:3]
-        u_yy = grad(
+        u_yy = autograd.grad(
             u_y, data, grad_outputs=torch.ones_like(prediction), create_graph=True
         )[0][:, 2:3]