 from .problem import AbstractProblem
 from .label_tensor import LabelTensor
-from .utils import merge_tensors
+from .utils import merge_tensors, PinaDataset

-torch.pi = torch.acos(torch.zeros(1)).item() * 2  # which is 3.1415927410125732
+
+torch.pi = torch.acos(torch.zeros(1)).item() * 2  # which is 3.1415927410125732


 class PINN(object):

     def __init__(self,
-                 problem,
-                 model,
-                 optimizer=torch.optim.Adam,
-                 lr=0.001,
-                 regularizer=0.00001,
-                 dtype=torch.float32,
-                 device='cpu',
-                 error_norm='mse'):
+                 problem,
+                 model,
+                 optimizer=torch.optim.Adam,
+                 lr=0.001,
+                 regularizer=0.00001,
+                 batch_size=None,
+                 dtype=torch.float32,
+                 device='cpu',
+                 error_norm='mse'):
         '''
         :param Problem problem: the formulation of the problem.
         :param torch.nn.Module model: the neural network model to use.
+        :param torch.optim optimizer: the neural network optimizer to use;
+            default is `torch.optim.Adam`.
         :param float lr: the learning rate; default is 0.001.
         :param float regularizer: the coefficient for the L2 regularizer term.
         :param type dtype: the data type to use for the model. Valid options are
             `torch.float32` and `torch.float64` (`torch.float16` only on GPU);
             default is `torch.float32`.
+        :param string device: the device used for training; default is 'cpu';
+            options include 'cuda' if CUDA is available.
+        :param string/int error_norm: the loss function used as minimizer;
+            default is the mean squared error 'mse'. If a string, the options
+            are mean error 'me' and mean squared error 'mse'. If an int, the
+            p-norm is computed, where p is the given int.
+        :param int batch_size: the batch size for the dataloader; default is
+            `None`, meaning all the sampled points are used in a single batch.
         '''

         if dtype == torch.float64:
@@ -38,7 +49,7 @@ def __init__(self,
         # self._architecture['input_dimension'] = self.problem.domain_bound.shape[0]
         # self._architecture['output_dimension'] = len(self.problem.variables)
         # if hasattr(self.problem, 'params_domain'):
-        #    self._architecture['input_dimension'] += self.problem.params_domain.shape[0]
+        # self._architecture['input_dimension'] += self.problem.params_domain.shape[0]

         self.error_norm = error_norm
@@ -59,6 +70,9 @@ def __init__(self,
         self.optimizer = optimizer(
             self.model.parameters(), lr=lr, weight_decay=regularizer)

+        self.batch_size = batch_size
+        self.data_set = PinaDataset(self)
+
     @property
     def problem(self):
         return self._problem
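
The constructor now stores the new `batch_size` and wraps the trainer in a `PinaDataset` imported from `.utils`, whose definition is not part of this diff. The rest of the patch only uses its `dataloader` attribute (see `train()` below), so a minimal sketch of the interface it seems to assume could look like this; the class body here is an assumption, not the library's actual code:

```python
# Hypothetical sketch of the PinaDataset interface assumed by this patch;
# the real class lives in pina's .utils module and is not shown in the diff.
class PinaDataset:

    def __init__(self, pinn):
        self._pinn = pinn

    @property
    def dataloader(self):
        # Delegate to the PINN's private builder (_create_dataloader below),
        # which returns one iterable of batches per condition/location.
        return self._pinn._create_dataloader()
```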
@@ -79,7 +93,7 @@ def _compute_norm(self, vec):
         :param vec torch.tensor: the tensor
         """
         if isinstance(self.error_norm, int):
-            return torch.linalg.vector_norm(vec, ord=self.error_norm, dtype=self.dytpe)
+            return torch.linalg.vector_norm(vec, ord=self.error_norm, dtype=self.dtype)
         elif self.error_norm == 'mse':
             return torch.mean(vec.pow(2))
         elif self.error_norm == 'me':
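
For reference, the three `error_norm` options reduce a residual vector as follows. This is a plain-torch illustration, not part of the patch; the 'me' branch is cut off by the hunk, so its mean-absolute-error body is assumed from its name:

```python
import torch

vec = torch.tensor([3.0, -4.0])

torch.mean(vec.pow(2))                # 'mse' -> 12.5
torch.mean(torch.abs(vec))            # 'me' (assumed) -> 3.5
torch.linalg.vector_norm(vec, ord=2)  # int p-norm (p=2) -> 5.0
```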
@@ -90,16 +104,16 @@ def _compute_norm(self, vec):
     def save_state(self, filename):

         checkpoint = {
-            'epoch': self.trained_epoch,
-            'model_state': self.model.state_dict(),
-            'optimizer_state': self.optimizer.state_dict(),
-            'optimizer_class': self.optimizer.__class__,
-            'history': self.history_loss,
-            'input_points_dict': self.input_pts,
+            'epoch': self.trained_epoch,
+            'model_state': self.model.state_dict(),
+            'optimizer_state': self.optimizer.state_dict(),
+            'optimizer_class': self.optimizer.__class__,
+            'history': self.history_loss,
+            'input_points_dict': self.input_pts,
         }

         # TODO save also architecture param?
-        #if isinstance(self.model, DeepFeedForward):
+        # if isinstance(self.model, DeepFeedForward):
         #    checkpoint['model_class'] = self.model.__class__
         #    checkpoint['model_structure'] = {
         #    }
@@ -110,7 +124,6 @@ def load_state(self, filename):
         checkpoint = torch.load(filename)
         self.model.load_state_dict(checkpoint['model_state'])

-
         self.optimizer = checkpoint['optimizer_class'](self.model.parameters())
         self.optimizer.load_state_dict(checkpoint['optimizer_state'])
@@ -121,6 +134,39 @@ def load_state(self, filename):

         return self

+    def _create_dataloader(self):
+        """Private method for creating the dataloader.
+
+        :return: the dataloader
+        :rtype: torch.utils.data.DataLoader
+        """
+        if self.batch_size is None:
+            return [self.input_pts]
+
+        def custom_collate(batch):
+            # extract the labels from the points
+            _, pts = list(batch[0].items())[0]
+            labels = pts.labels
+            # call the default torch collate
+            collate_res = default_collate(batch)
+            # save the collate result in a dict, restoring the labels
+            res = {}
+            for key, val in collate_res.items():
+                val.labels = labels
+                res[key] = val
+            return res
+
+        # create one dataset per sampling location
+        datasets = [MyDataSet(key, val)
+                    for key, val in self.input_pts.items()]
+        # create one dataloader per dataset
+        dataloaders = [DataLoader(dataset=dat,
+                                  batch_size=self.batch_size,
+                                  collate_fn=custom_collate)
+                       for dat in datasets]
+
+        return dict(zip(self.input_pts.keys(), dataloaders))
+
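
Note that `_create_dataloader` references `MyDataSet`, `DataLoader`, and `default_collate`, none of which are imported in the hunks shown. `DataLoader` and `default_collate` come from `torch.utils.data`, while `MyDataSet` is not defined anywhere in this diff. A minimal map-style dataset consistent with `custom_collate` above (each item must be a `{location: point}` dict, since the collate reads `batch[0].items()`) might look like this sketch; the class body is an assumption:

```python
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate

# Hypothetical sketch of MyDataSet: a map-style dataset that yields one
# {location: single-point tensor} dict per index, matching custom_collate.
class MyDataSet(Dataset):

    def __init__(self, location, pts):
        self.location = location
        self.pts = pts

    def __len__(self):
        # one item per sampled point (rows of the points tensor)
        return self.pts.shape[0]

    def __getitem__(self, index):
        return {self.location: self.pts[index]}
```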
     def span_pts(self, *args, **kwargs):
         """
         >>> pinn.span_pts(n=10, mode='grid')
@@ -155,59 +201,69 @@ def span_pts(self, *args, **kwargs):
                                 argument['n'],
                                 argument['mode'],
                                 variables=argument['variables'])
-                for argument in arguments)
+                for argument in arguments)
             pts = merge_tensors(samples)

             # TODO
             # pts = pts.double()
-            pts = pts.to(dtype=self.dtype, device=self.device)
-            pts.requires_grad_(True)
-            pts.retain_grad()
-
             self.input_pts[location] = pts

     def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):

         epoch = 0
+        data_loader = self.data_set.dataloader

         header = []
         for condition_name in self.problem.conditions:
             condition = self.problem.conditions[condition_name]

-            if (hasattr(condition, 'function') and
-                    isinstance(condition.function, list)):
-                for function in condition.function:
-                    header.append(f'{condition_name}{function.__name__}')
-            else:
-                header.append(f'{condition_name}')
+            if hasattr(condition, 'function'):
+                if isinstance(condition.function, list):
+                    for function in condition.function:
+                        header.append(f'{condition_name}{function.__name__}')
+
+                    continue
+
+            header.append(f'{condition_name}')

         while True:
+
             losses = []

             for condition_name in self.problem.conditions:
                 condition = self.problem.conditions[condition_name]

-                if hasattr(condition, 'function'):
-                    pts = self.input_pts[condition_name]
-                    predicted = self.model(pts)
-                    for function in condition.function:
-                        residuals = function(pts, predicted)
+                for batch in data_loader[condition_name]:
+
+                    single_loss = []
+
+                    if hasattr(condition, 'function'):
+                        pts = batch[condition_name]
+                        pts = pts.to(dtype=self.dtype, device=self.device)
+                        pts.requires_grad_(True)
+                        pts.retain_grad()
+
+                        predicted = self.model(pts)
+                        for function in condition.function:
+                            residuals = function(pts, predicted)
+                            local_loss = (
+                                condition.data_weight * self._compute_norm(
+                                    residuals))
+                            single_loss.append(local_loss)
+                    elif hasattr(condition, 'output_points'):
+                        pts = condition.input_points.to(
+                            dtype=self.dtype, device=self.device)
+                        predicted = self.model(pts)
+                        residuals = predicted - condition.output_points
                         local_loss = (
-                            condition.data_weight * self._compute_norm(
-                                residuals))
-                        losses.append(local_loss)
-                elif hasattr(condition, 'output_points'):
-                    pts = condition.input_points
-                    predicted = self.model(pts)
-                    residuals = predicted - condition.output_points
-                    local_loss = (
-                        condition.data_weight * self._compute_norm(residuals))
-                    losses.append(local_loss)
-
-            self.optimizer.zero_grad()
-
-            sum(losses).backward()
-            self.optimizer.step()
+                            condition.data_weight * self._compute_norm(residuals))
+                        single_loss.append(local_loss)
+
+                    self.optimizer.zero_grad()
+                    sum(single_loss).backward()
+                    self.optimizer.step()
+
+                    losses.append(sum(single_loss))

             if save_loss and (epoch % save_loss == 0 or epoch == 0):
                 self.history_loss[epoch] = [
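
The net effect of the hunk above is that the optimizer now takes one step per batch and per condition, instead of a single step per epoch over all points, and the device/dtype transfer with `requires_grad_` moves from `span_pts` into the batch loop. Stripped of the PINA-specific objects, the update is the standard mini-batch loop; a self-contained illustration with a dummy model and data (not from the patch):

```python
import torch
from torch.utils.data import DataLoader

# Dummy stand-ins for the model and the sampled collocation points.
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
pts = torch.rand(20, 2)
loader = DataLoader(pts, batch_size=8)

for batch in loader:
    residuals = model(batch)             # stand-in for the condition residual
    loss = torch.mean(residuals.pow(2))  # the 'mse' norm of _compute_norm
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()                     # one update per mini-batch
```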
@@ -221,7 +277,8 @@ def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):

         if isinstance(stop, int):
             if epoch == stop:
-                print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='')
+                print('[epoch {:05d}] {:.6e} '.format(
+                    self.trained_epoch, sum(losses).item()), end='')
                 for loss in losses:
                     print('{:.6e} '.format(loss.item()), end='')
                 print()
@@ -236,7 +293,8 @@ def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):
                     print('{:12.12s} '.format(name), end='')
                 print()

-            print('[epoch {:05d}] {:.6e} '.format(self.trained_epoch, sum(losses).item()), end='')
+            print('[epoch {:05d}] {:.6e} '.format(
+                self.trained_epoch, sum(losses).item()), end='')
             for loss in losses:
                 print('{:.6e} '.format(loss.item()), end='')
             print()
@@ -246,7 +304,6 @@ def train(self, stop=100, frequency_print=2, save_loss=1, trial=None):

         return sum(losses).item()

-
     def error(self, dtype='l2', res=100):

         import numpy as np
@@ -261,7 +318,8 @@ def error(self, dtype='l2', res=100):
             grids_container = self.problem.data_solution['grid']
             Z_true = self.problem.data_solution['grid_solution']
         try:
-            unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(dtype=self.dtype, device=self.device)
+            unrolled_pts = torch.tensor([t.flatten() for t in grids_container]).T.to(
+                dtype=self.dtype, device=self.device)
             Z_pred = self.model(unrolled_pts)
             Z_pred = Z_pred.detach().numpy().reshape(grids_container[0].shape)
@@ -273,4 +331,5 @@ def error(self, dtype='l2', res=100):
         except:
             print("")
             print("Something went wrong...")
-            print("Not able to compute the error. Please pass a data solution or a true solution")
+            print(
+                "Not able to compute the error. Please pass a data solution or a true solution")