Commit c4e39b5

batch_enhancement (#51)
1 parent 05131cd commit c4e39b5

File tree

4 files changed: +238 −61 lines changed

pina/label_tensor.py

Lines changed: 9 additions & 1 deletion

@@ -79,7 +79,7 @@ def labels(self):
 
     @labels.setter
     def labels(self, labels):
-        if len(labels) != self.shape[1]: # small check
+        if len(labels) != self.shape[1]:  # small check
            raise ValueError(
                'the tensor has not the same number of columns of '
                'the passed labels.')
@@ -106,6 +106,14 @@ def to(self, *args, **kwargs):
         new.data = tmp.data
         return new
 
+    def select(self, *args, **kwargs):
+        """
+        Performs Tensor selection. For more details, see :meth:`torch.Tensor.select`.
+        """
+        tmp = super().select(*args, **kwargs)
+        tmp._labels = self._labels
+        return tmp
+
     def extract(self, label_to_extract):
         """
         Extract the subset of the original tensor by returning all the columns
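
The practical effect of the new select override: label metadata now survives selection, where a plain torch.Tensor.select would return an unlabelled tensor. A minimal sketch, assuming LabelTensor is constructed from a tensor and a list of column labels (as the labels setter above suggests); the values and label names are made up for illustration:

    import torch
    from pina.label_tensor import LabelTensor

    # a 4x3 tensor whose columns are labelled 'x', 'y', 'mu'
    pts = LabelTensor(torch.rand(4, 3), ['x', 'y', 'mu'])

    # select() now carries the labels over to the result,
    # so label-aware code downstream keeps working
    row = pts.select(0, 0)
    print(row.labels)  # ['x', 'y', 'mu']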

pina/pinn.py

Lines changed: 115 additions & 56 deletions
@@ -3,30 +3,41 @@
 
 from .problem import AbstractProblem
 from .label_tensor import LabelTensor
-from .utils import merge_tensors
+from .utils import merge_tensors, PinaDataset
 
-torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
+
+torch.pi = torch.acos(torch.zeros(1)).item() * 2  # which is 3.1415927410125732
 
 
 class PINN(object):
 
     def __init__(self,
-            problem,
-            model,
-            optimizer=torch.optim.Adam,
-            lr=0.001,
-            regularizer=0.00001,
-            dtype=torch.float32,
-            device='cpu',
-            error_norm='mse'):
+                 problem,
+                 model,
+                 optimizer=torch.optim.Adam,
+                 lr=0.001,
+                 regularizer=0.00001,
+                 batch_size=None,
+                 dtype=torch.float32,
+                 device='cpu',
+                 error_norm='mse'):
         '''
         :param Problem problem: the formulation of the problem.
         :param torch.nn.Module model: the neural network model to use.
+        :param torch.optim optimizer: the neural network optimizer to use;
+            default is `torch.optim.Adam`.
         :param float lr: the learning rate; default is 0.001.
         :param float regularizer: the coefficient for the L2 regularizer term.
         :param type dtype: the data type to use for the model. Valid options are
             `torch.float32` and `torch.float64` (`torch.float16` only on GPU);
             default is `torch.float32`.
+        :param string device: the device used for training; default is 'cpu';
+            options include 'cuda' if CUDA is available.
+        :param string/int error_norm: the loss function used as minimizer;
+            default is the mean square error 'mse'. If a string, options include
+            the mean error 'me' and the mean square error 'mse'. If an int,
+            the p-norm is calculated, where p is specified by the int input.
+        :param int batch_size: the batch size for the dataloader; default is None (a single batch).
         '''
 
         if dtype == torch.float64:
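
A minimal sketch of how the new batch_size argument fits the constructor; `problem` stands in for any concrete AbstractProblem subclass and is assumed to be defined elsewhere, not part of this commit:

    import torch
    from pina.pinn import PINN

    model = torch.nn.Linear(2, 1)    # any torch.nn.Module works here

    pinn = PINN(problem,             # an AbstractProblem defined elsewhere
                model,
                optimizer=torch.optim.Adam,
                lr=0.001,
                batch_size=8,        # new: sample points served in batches of 8
                error_norm='mse')    # or 'me', or an int p for the p-norm
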
@@ -38,7 +49,7 @@ def __init__(self,
         # self._architecture['input_dimension'] = self.problem.domain_bound.shape[0]
         # self._architecture['output_dimension'] = len(self.problem.variables)
         # if hasattr(self.problem, 'params_domain'):
-        #     self._architecture['input_dimension'] += self.problem.params_domain.shape[0]
+        # self._architecture['input_dimension'] += self.problem.params_domain.shape[0]
 
         self.error_norm = error_norm
@@ -59,6 +70,9 @@ def __init__(self,
         self.optimizer = optimizer(
             self.model.parameters(), lr=lr, weight_decay=regularizer)
 
+        self.batch_size = batch_size
+        self.data_set = PinaDataset(self)
+
     @property
     def problem(self):
         return self._problem
@@ -79,7 +93,7 @@ def _compute_norm(self, vec):
         :param vec torch.tensor: the tensor
         """
         if isinstance(self.error_norm, int):
-            return torch.linalg.vector_norm(vec, ord = self.error_norm, dtype=self.dytpe)
+            return torch.linalg.vector_norm(vec, ord=self.error_norm, dtype=self.dytpe)
         elif self.error_norm == 'mse':
             return torch.mean(vec.pow(2))
         elif self.error_norm == 'me':
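
For reference, the error_norm branches reduce a residual vector as follows; this standalone snippet mirrors the logic above with the dtype argument dropped for brevity. The 'me' branch is cut off in the diff, so its body here is an assumption (the docstring's "mean error", taken as the mean absolute value):

    import torch

    residuals = torch.tensor([0.5, -1.0, 2.0])

    mse = torch.mean(residuals.pow(2))               # error_norm='mse'
    me = torch.mean(torch.abs(residuals))            # error_norm='me' (assumed)
    p3 = torch.linalg.vector_norm(residuals, ord=3)  # error_norm=3 (p-norm, p=3)
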
@@ -90,16 +104,16 @@ def _compute_norm(self, vec):
     def save_state(self, filename):
 
         checkpoint = {
-                'epoch': self.trained_epoch,
-                'model_state': self.model.state_dict(),
-                'optimizer_state' : self.optimizer.state_dict(),
-                'optimizer_class' : self.optimizer.__class__,
-                'history' : self.history_loss,
-                'input_points_dict' : self.input_pts,
+            'epoch': self.trained_epoch,
+            'model_state': self.model.state_dict(),
+            'optimizer_state': self.optimizer.state_dict(),
+            'optimizer_class': self.optimizer.__class__,
+            'history': self.history_loss,
+            'input_points_dict': self.input_pts,
         }
 
         # TODO save also architecture param?
-        #if isinstance(self.model, DeepFeedForward):
+        # if isinstance(self.model, DeepFeedForward):
         #     checkpoint['model_class'] = self.model.__class__
         #     checkpoint['model_structure'] = {
         #     }
@@ -110,7 +124,6 @@ def load_state(self, filename):
         checkpoint = torch.load(filename)
         self.model.load_state_dict(checkpoint['model_state'])
 
-
         self.optimizer = checkpoint['optimizer_class'](self.model.parameters())
         self.optimizer.load_state_dict(checkpoint['optimizer_state'])
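
Typical round-trip usage of the two methods above, continuing the earlier sketch; the checkpoint filename is arbitrary:

    pinn.train(stop=1000)
    pinn.save_state('pina.checkpoint')

    # later: rebuild the PINN, then restore weights, optimizer state,
    # loss history and sample points from the checkpoint
    pinn = PINN(problem, model)
    pinn.load_state('pina.checkpoint')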

@@ -121,6 +134,39 @@ def load_state(self, filename):
 
         return self
 
+    def _create_dataloader(self):
+        """Private method for creating dataloader
+
+        :return: dataloader
+        :rtype: torch.utils.data.DataLoader
+        """
+        if self.batch_size is None:
+            return [self.input_pts]
+
+        def custom_collate(batch):
+            # extracting pts labels
+            _, pts = list(batch[0].items())[0]
+            labels = pts.labels
+            # calling default torch collate
+            collate_res = default_collate(batch)
+            # save collate result in dict
+            res = {}
+            for key, val in collate_res.items():
+                val.labels = labels
+                res[key] = val
+            return res
+
+        # creating dataset, list of dataset for each location
+        datasets = [MyDataSet(key, val)
+                    for key, val in self.input_pts.items()]
+        # creating dataloader
+        dataloaders = [DataLoader(dataset=dat,
+                                  batch_size=self.batch_size,
+                                  collate_fn=custom_collate)
+                       for dat in datasets]
+
+        return dict(zip(self.input_pts.keys(), dataloaders))
+
     def span_pts(self, *args, **kwargs):
         """
         >>> pinn.span_pts(n=10, mode='grid')
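
The method above references MyDataSet and default_collate, which come from the pina/utils.py changes not shown in this view. A self-contained sketch of the same pattern under that assumption: one Dataset per location whose items are {location: point} dicts, plus a collate function that re-attaches the labels default_collate drops. The real MyDataSet may differ in detail:

    import torch
    from torch.utils.data import DataLoader, Dataset
    from torch.utils.data.dataloader import default_collate

    from pina.label_tensor import LabelTensor

    class MyDataSet(Dataset):
        """One dataset per location; each item is a {location: point} dict."""

        def __init__(self, location, tensor):
            self.location = location
            self.tensor = tensor

        def __len__(self):
            return self.tensor.shape[0]

        def __getitem__(self, index):
            # select() keeps the labels thanks to the label_tensor.py change above
            return {self.location: self.tensor.select(0, index)}

    def custom_collate(batch):
        # default_collate stacks the points into a plain batch tensor,
        # so grab the labels from the first item and re-attach them after
        _, pts = list(batch[0].items())[0]
        labels = pts.labels
        res = {}
        for key, val in default_collate(batch).items():
            val.labels = labels
            res[key] = val
        return res

    pts = LabelTensor(torch.rand(10, 2), ['x', 'y'])
    loader = DataLoader(MyDataSet('gamma1', pts),
                        batch_size=4, collate_fn=custom_collate)

    for batch in loader:
        print(batch['gamma1'].labels)  # ['x', 'y'] on every batch
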
@@ -155,59 +201,69 @@ def span_pts(self, *args, **kwargs):
                     argument['n'],
                     argument['mode'],
                     variables=argument['variables'])
-                for argument in arguments)
+                       for argument in arguments)
             pts = merge_tensors(samples)
 
             # TODO
             # pts = pts.double()
-            pts = pts.to(dtype=self.dtype, device=self.device)
-            pts.requires_grad_(True)
-            pts.retain_grad()
-
             self.input_pts[location] = pts
 
     def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
 
         epoch = 0
+        data_loader = self.data_set.dataloader
 
         header = []
         for condition_name in self.problem.conditions:
             condition = self.problem.conditions[condition_name]
 
-            if (hasattr(condition, 'function') and
-                    isinstance(condition.function, list)):
-                for function in condition.function:
-                    header.append(f'{condition_name}{function.__name__}')
-            else:
-                header.append(f'{condition_name}')
+            if hasattr(condition, 'function'):
+                if isinstance(condition.function, list):
+                    for function in condition.function:
+                        header.append(f'{condition_name}{function.__name__}')
+
+                continue
+
+            header.append(f'{condition_name}')
 
         while True:
+
             losses = []
 
             for condition_name in self.problem.conditions:
                 condition = self.problem.conditions[condition_name]
 
-                if hasattr(condition, 'function'):
-                    pts = self.input_pts[condition_name]
-                    predicted = self.model(pts)
-                    for function in condition.function:
-                        residuals = function(pts, predicted)
+                for batch in data_loader[condition_name]:
+
+                    single_loss = []
+
+                    if hasattr(condition, 'function'):
+                        pts = batch[condition_name]
+                        pts = pts.to(dtype=self.dtype, device=self.device)
+                        pts.requires_grad_(True)
+                        pts.retain_grad()
+
+                        predicted = self.model(pts)
+                        for function in condition.function:
+                            residuals = function(pts, predicted)
+                            local_loss = (
+                                condition.data_weight*self._compute_norm(
+                                    residuals))
+                            single_loss.append(local_loss)
+                    elif hasattr(condition, 'output_points'):
+                        pts = condition.input_points.to(
+                            dtype=self.dtype, device=self.device)
+                        predicted = self.model(pts)
+                        residuals = predicted - condition.output_points
                         local_loss = (
-                            condition.data_weight*self._compute_norm(
-                                residuals))
-                        losses.append(local_loss)
-                elif hasattr(condition, 'output_points'):
-                    pts = condition.input_points
-                    predicted = self.model(pts)
-                    residuals = predicted - condition.output_points
-                    local_loss = (
-                        condition.data_weight*self._compute_norm(residuals))
-                    losses.append(local_loss)
-
-            self.optimizer.zero_grad()
-
-            sum(losses).backward()
-            self.optimizer.step()
+                            condition.data_weight*self._compute_norm(residuals))
+                        single_loss.append(local_loss)
+
+                    self.optimizer.zero_grad()
+                    sum(single_loss).backward()
+                    self.optimizer.step()
+
+                    losses.append(sum(single_loss))
 
             if save_loss and (epoch % save_loss == 0 or epoch == 0):
                 self.history_loss[epoch] = [
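
Net effect of the restructuring above: the optimizer now steps once per batch instead of once per epoch, points are moved to the target device per batch, and the per-condition losses reported below are sums of the batch losses. The call signature is unchanged, e.g. (continuing the earlier sketch):

    final_loss = pinn.train(stop=1000, frequency_print=100)
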
@@ -221,7 +277,8 @@ def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
 
             if isinstance(stop, int):
                 if epoch == stop:
-                    print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='')
+                    print('[epoch {:05d}] {:.6e} '.format(
+                        self.trained_epoch, sum(losses).item()), end='')
                     for loss in losses:
                         print('{:.6e} '.format(loss.item()), end='')
                     print()
@@ -236,7 +293,8 @@ def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
                     print('{:12.12s} '.format(name), end='')
                 print()
 
-                print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='')
+                print('[epoch {:05d}] {:.6e} '.format(
+                    self.trained_epoch, sum(losses).item()), end='')
                 for loss in losses:
                     print('{:.6e} '.format(loss.item()), end='')
                 print()
@@ -246,7 +304,6 @@ def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
 
         return sum(losses).item()
 
-
     def error(self, dtype='l2', res=100):
 
         import numpy as np
@@ -261,7 +318,8 @@ def error(self, dtype='l2', res=100):
             grids_container = self.problem.data_solution['grid']
             Z_true = self.problem.data_solution['grid_solution']
         try:
-            unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(dtype=self.dtype, device=self.device)
+            unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(
+                dtype=self.dtype, device=self.device)
             Z_pred = self.model(unrolled_pts)
             Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape)
 
@@ -273,4 +331,5 @@ def error(self, dtype='l2', res=100):
         except:
             print("")
             print("Something went wrong...")
-            print("Not able to compute the error. Please pass a data solution or a true solution")
+            print(
+                "Not able to compute the error. Please pass a data solution or a true solution")
