Add verbose to dde.model.train and dde.model.compile (#1879)
KangyuWeng authored Nov 7, 2024
1 parent 4c8b2d2 · commit c816a11
Showing 2 changed files with 29 additions and 23 deletions.
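The change is symmetric: `compile` and `train` each gain a `verbose` keyword defaulting to 1, which preserves the old behavior, while `verbose=0` suppresses the banners and per-step loss output. A minimal usage sketch of the new parameters (the toy Poisson problem is illustrative only and not part of this commit):

```python
import deepxde as dde

def pde(x, y):
    # Residual of y'' = 2 on (-1, 1); exact solution is y = x**2 - 1.
    return dde.grad.hessian(y, x) - 2

geom = dde.geometry.Interval(-1, 1)
bc = dde.icbc.DirichletBC(geom, lambda x: 0, lambda x, on_boundary: on_boundary)
data = dde.data.PDE(geom, pde, bc, num_domain=16, num_boundary=2)
net = dde.nn.FNN([1, 20, 20, 1], "tanh", "Glorot uniform")
model = dde.Model(data, net)

# New in this commit: verbose=0 silences "Compiling model..." ...
model.compile("adam", lr=1e-3, verbose=0)
# ... and "Training model...", the per-step loss lines, and the final summary.
losshistory, train_state = model.train(iterations=1000, verbose=0)
```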
49 changes: 27 additions & 22 deletions deepxde/model.py
@@ -65,6 +65,7 @@ def compile(
         decay=None,
         loss_weights=None,
         external_trainable_variables=None,
+        verbose=1,
     ):
         """Configures the model for training.
@@ -114,8 +115,9 @@ def compile(
                 physics systems that need to be recovered. If the backend is
                 tensorflow.compat.v1, `external_trainable_variables` is ignored, and all
                 trainable ``dde.Variable`` objects are automatically collected.
+            verbose (Integer): Controls the verbosity of the compile process.
         """
-        if config.rank == 0:
+        if verbose > 0 and config.rank == 0:
             print("Compiling model...")
         self.opt_name = optimizer
         loss_fn = losses_module.get(loss)
@@ -585,6 +587,7 @@ def train(
         model_restore_path=None,
         model_save_path=None,
         epochs=None,
+        verbose=1,
     ):
         """Trains the model.
@@ -610,6 +613,7 @@
             model_save_path (String): Prefix of filenames created for the checkpoint.
             epochs (Integer): Deprecated alias to `iterations`. This will be removed in
                 a future version.
+            verbose (Integer): Controls the verbosity of the train process.
         """
         if iterations is None and epochs is not None:
             print(
@@ -635,36 +639,36 @@ def train(
         if model_restore_path is not None:
             self.restore(model_restore_path, verbose=1)
 
-        if config.rank == 0:
+        if verbose > 0 and config.rank == 0:
             print("Training model...\n")
         self.stop_training = False
         self.train_state.set_data_train(*self.data.train_next_batch(self.batch_size))
         self.train_state.set_data_test(*self.data.test())
-        self._test()
+        self._test(verbose=verbose)
         self.callbacks.on_train_begin()
         if optimizers.is_external_optimizer(self.opt_name):
             if backend_name == "tensorflow.compat.v1":
-                self._train_tensorflow_compat_v1_scipy(display_every)
+                self._train_tensorflow_compat_v1_scipy(display_every, verbose=verbose)
             elif backend_name == "tensorflow":
-                self._train_tensorflow_tfp()
+                self._train_tensorflow_tfp(verbose=verbose)
             elif backend_name == "pytorch":
-                self._train_pytorch_lbfgs()
+                self._train_pytorch_lbfgs(verbose=verbose)
             elif backend_name == "paddle":
-                self._train_paddle_lbfgs()
+                self._train_paddle_lbfgs(verbose=verbose)
         else:
             if iterations is None:
                 raise ValueError("No iterations for {}.".format(self.opt_name))
-            self._train_sgd(iterations, display_every)
+            self._train_sgd(iterations, display_every, verbose=verbose)
         self.callbacks.on_train_end()
 
-        if config.rank == 0:
+        if verbose > 0 and config.rank == 0:
             print("")
             display.training_display.summary(self.train_state)
         if model_save_path is not None:
             self.save(model_save_path, verbose=1)
         return self.losshistory, self.train_state
 
-    def _train_sgd(self, iterations, display_every):
+    def _train_sgd(self, iterations, display_every, verbose=1):
         for i in range(iterations):
             self.callbacks.on_epoch_begin()
             self.callbacks.on_batch_begin()
@@ -681,15 +685,15 @@ def _train_sgd(self, iterations, display_every):
             self.train_state.epoch += 1
             self.train_state.step += 1
             if self.train_state.step % display_every == 0 or i + 1 == iterations:
-                self._test()
+                self._test(verbose=verbose)
 
             self.callbacks.on_batch_end()
             self.callbacks.on_epoch_end()
 
             if self.stop_training:
                 break
 
-    def _train_tensorflow_compat_v1_scipy(self, display_every):
+    def _train_tensorflow_compat_v1_scipy(self, display_every, verbose=1):
         def loss_callback(loss_train, loss_test, *args):
             self.train_state.epoch += 1
             self.train_state.step += 1
@@ -703,7 +707,8 @@ def loss_callback(loss_train, loss_test, *args):
                     self.train_state.loss_test,
                     None,
                 )
-                display.training_display(self.train_state)
+                if verbose > 0:
+                    display.training_display(self.train_state)
             for cb in self.callbacks.callbacks:
                 if type(cb).__name__ == "VariableValue":
                     cb.epochs_since_last += 1
@@ -736,9 +741,9 @@ def loss_callback(loss_train, loss_test, *args):
             fetches=fetches,
             loss_callback=loss_callback,
         )
-        self._test()
+        self._test(verbose=verbose)
 
-    def _train_tensorflow_tfp(self):
+    def _train_tensorflow_tfp(self, verbose=1):
         # There is only one optimization step. If using multiple steps with/without
         # previous_optimizer_results, L-BFGS failed to reach a small error. The reason
         # could be that tfp.optimizer.lbfgs_minimize will start from scratch for each
@@ -756,12 +761,12 @@ def _train_tensorflow_tfp(self):
             n_iter += results.num_iterations.numpy()
             self.train_state.epoch += results.num_iterations.numpy()
             self.train_state.step += results.num_iterations.numpy()
-            self._test()
+            self._test(verbose=verbose)
 
             if results.converged or results.failed:
                 break
 
-    def _train_pytorch_lbfgs(self):
+    def _train_pytorch_lbfgs(self, verbose=1):
         prev_n_iter = 0
         while prev_n_iter < optimizers.LBFGS_options["maxiter"]:
             self.callbacks.on_epoch_begin()
@@ -784,15 +789,15 @@
             self.train_state.epoch += n_iter - prev_n_iter
             self.train_state.step += n_iter - prev_n_iter
             prev_n_iter = n_iter
-            self._test()
+            self._test(verbose=verbose)
 
             self.callbacks.on_batch_end()
             self.callbacks.on_epoch_end()
 
             if self.stop_training:
                 break
 
-    def _train_paddle_lbfgs(self):
+    def _train_paddle_lbfgs(self, verbose=1):
         prev_n_iter = 0
 
         while prev_n_iter < optimizers.LBFGS_options["maxiter"]:
@@ -816,15 +821,15 @@
             self.train_state.epoch += n_iter - prev_n_iter
             self.train_state.step += n_iter - prev_n_iter
             prev_n_iter = n_iter
-            self._test()
+            self._test(verbose=verbose)
 
             self.callbacks.on_batch_end()
             self.callbacks.on_epoch_end()
 
             if self.stop_training:
                 break
 
-    def _test(self):
+    def _test(self, verbose=1):
         # TODO Now only print the training loss in rank 0. The correct way is to print the average training loss of all ranks.
         (
             self.train_state.y_pred_train,
@@ -867,7 +872,7 @@ def _test(self):
             or np.isnan(self.train_state.loss_test).any()
         ):
             self.stop_training = True
-        if config.rank == 0:
+        if verbose > 0 and config.rank == 0:
             display.training_display(self.train_state)
 
     def predict(self, x, operator=None, callbacks=None):
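Because `train` threads the flag into every backend-specific driver (`_train_sgd`, `_train_tensorflow_compat_v1_scipy`, `_train_tensorflow_tfp`, `_train_pytorch_lbfgs`, `_train_paddle_lbfgs`) and into `_test`, a single argument quiets an entire run, which is handy for sweeps. A sketch reusing `data` from the example above (`best_loss_train` is the summed training loss at the best step, recorded on `TrainState`):

```python
# Quiet learning-rate sweep: nothing is printed except our own summary lines.
for lr in (1e-2, 1e-3, 1e-4):
    net = dde.nn.FNN([1, 20, 20, 1], "tanh", "Glorot uniform")  # fresh weights per run
    model = dde.Model(data, net)
    model.compile("adam", lr=lr, verbose=0)
    _, train_state = model.train(iterations=500, verbose=0)
    print(f"lr={lr:g}  best train loss={train_state.best_loss_train:.3e}")
```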
3 changes: 2 additions & 1 deletion deepxde/utils/internal.py
@@ -21,7 +21,8 @@ def wrapper(*args, **kwargs):
         ts = timeit.default_timer()
         result = f(*args, **kwargs)
         te = timeit.default_timer()
-        if config.rank == 0:
+        verbose = kwargs.get('verbose', 1)
+        if verbose > 0 and config.rank == 0:
             print("%r took %f s\n" % (f.__name__, te - ts))
             sys.stdout.flush()
         return result
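For context, `wrapper` here is the inner function of DeepXDE's `timing` decorator, which wraps methods such as `Model.compile` and `Model.train`; this hunk makes the timing printout honor the same flag. A hedged reconstruction of the whole decorator after the change (lines outside the hunk are inferred from the visible context, not shown in the diff):

```python
import sys
import timeit
from functools import wraps

from .. import config  # deepxde.config; config.rank is the process rank under MPI


def timing(f):
    """Decorator for measuring the execution time of methods."""

    @wraps(f)
    def wrapper(*args, **kwargs):
        ts = timeit.default_timer()
        result = f(*args, **kwargs)
        te = timeit.default_timer()
        # Read verbose from kwargs: the decorator only sees the flag when the
        # wrapped method is called with verbose=... as a keyword argument.
        verbose = kwargs.get('verbose', 1)
        if verbose > 0 and config.rank == 0:
            print("%r took %f s\n" % (f.__name__, te - ts))
            sys.stdout.flush()
        return result

    return wrapper
```

One consequence of the `kwargs.get` lookup: a `verbose` value passed positionally would not reach the decorator, so callers should pass it as a keyword, matching the keyword-with-default style used in `compile` and `train`.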
