Added functionality to take predefined tqdm loops as input to solver.fit() #221

Open
wants to merge 9 commits into master
6 changes: 3 additions & 3 deletions neurodiffeq/generators.py
@@ -7,20 +7,20 @@

def _chebyshev_first(a, b, n):
nodes = torch.cos(((torch.arange(n) + 0.5) / n) * np.pi)
nodes = ((a + b) + (b - a) * nodes) / 2
nodes = ((a + b) + (a - b) * nodes) / 2
Member:
Is this change meant to ensure the returned points are increasing?

Contributor Author:
Yes, that's the reason.
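
As a quick illustration (an editor's sketch, not part of the diff), assuming torch and numpy are imported as at the top of generators.py, the sign flip makes the returned nodes ascend instead of descend:

```python
import numpy as np
import torch

n, a, b = 5, 0.0, 1.0
raw = torch.cos(((torch.arange(n) + 0.5) / n) * np.pi)  # cos values run from ~0.95 down to ~-0.95

old = ((a + b) + (b - a) * raw) / 2  # descending: tensor([0.9755, 0.7939, 0.5000, 0.2061, 0.0245])
new = ((a + b) + (a - b) * raw) / 2  # ascending:  tensor([0.0245, 0.2061, 0.5000, 0.7939, 0.9755])
```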

nodes.requires_grad_(True)
return nodes


def _chebyshev_second(a, b, n):
nodes = torch.cos(torch.arange(n) / float(n - 1) * np.pi)
nodes = ((a + b) + (b - a) * nodes) / 2
nodes = ((a + b) + (a - b) * nodes) / 2
nodes.requires_grad_(True)
return nodes

def _chebyshev_second_noisy(a, b, n):
nodes = torch.cos((torch.arange(n) + (torch.rand(n) * 2 - 1)) / float(n - 1) * np.pi)
nodes = ((a + b) + (b - a) * nodes) / 2
nodes = ((a + b) + (a - b) * nodes) / 2
nodes.requires_grad_(True)
return nodes

20 changes: 12 additions & 8 deletions neurodiffeq/solvers.py
@@ -440,7 +440,7 @@ def _update_best(self, key):
self.lowest_loss = current_loss
self.best_nets = deepcopy(self.nets)

def fit(self, max_epochs, callbacks=(), tqdm_file=sys.stderr, **kwargs):
def fit(self, max_epochs, callbacks=(), tqdm_file='default', **kwargs):
r"""Run multiple epochs of training and validation, update best loss at the end of each epoch.

If ``callbacks`` is passed, callbacks are run, one at a time,
@@ -471,19 +471,21 @@ def fit(self, max_epochs, callbacks=(), tqdm_file=sys.stderr, **kwargs):
callbacks = [monitor.to_callback()] + list(callbacks)
if kwargs:
raise ValueError(f'Unknown keyword argument(s): {list(kwargs.keys())}') # pragma: no cover

if tqdm_file is None:
loop = range(max_epochs)
else:

flag = True
if 'default' in str(tqdm_file):
loop = tqdm(
range(max_epochs),
total = max_epochs,
desc='Training Progress',
colour='blue',
file=tqdm_file,
dynamic_ncols=True,
)
elif tqdm_file is not None:
loop = tqdm_file
else:
flag = False

for local_epoch in loop:
for local_epoch in range(max_epochs):
# stop training if self._stop_training is set to True by a callback
if self._stop_training:
break
@@ -495,6 +497,8 @@ def fit(self, max_epochs, callbacks=(), tqdm_file=sys.stderr, **kwargs):

for cb in callbacks:
cb(self)
if flag:
loop.update(1)

@abstractmethod
def get_solution(self, copy=True, best=True):
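
For context, here is a hypothetical usage sketch of the behaviour this PR introduces (the solver objects and epoch counts are illustrative, not taken from the PR): fit() either builds its own progress bar, reuses a predefined tqdm loop and advances it by one per epoch, or shows no bar at all.

```python
from tqdm import tqdm

# `solver`, `solver_a`, `solver_b` stand for already-constructed neurodiffeq
# solvers (e.g. Solver1D); their construction is omitted here.

# 1. Default: the solver creates its own 'Training Progress' bar.
solver.fit(max_epochs=1000)

# 2. Predefined loop: pass an existing tqdm bar, e.g. one shared across
#    several solvers; fit() calls bar.update(1) at the end of each epoch.
bar = tqdm(total=2000, desc='Overall progress', colour='green')
solver_a.fit(max_epochs=1000, tqdm_file=bar)
solver_b.fit(max_epochs=1000, tqdm_file=bar)

# 3. No progress bar: pass None.
solver.fit(max_epochs=1000, tqdm_file=None)
```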