Towards SoftAdapt loss balancing for tf.compat.v1 #1586
base: master
Conversation
deepxde/callbacks.py (Outdated)

```python
loss_weights = dde.Variable(loss_weights, trainable=False, dtype=loss_weights.dtype)
loss_weights *= 0
```
I am trying to allow `loss_weights` to be a `Variable`, so that the loss function updates automatically every time the weights change. Any clue @lululxvi?

Here, I was trying to set the `loss_weights` to 0, so the loss should be 0 for the following epochs (which is not the case so far). Should we define `loss_weights` differently in `model.compile`?
Maybe we need to work here (lines 169 to 183 in 3b08fe3):
```python
def losses(losses_fn):
    # Data losses
    losses = losses_fn(
        self.net.targets, self.net.outputs, loss_fn, self.net.inputs, self
    )
    if not isinstance(losses, list):
        losses = [losses]
    # Regularization loss
    if self.net.regularizer is not None:
        losses.append(tf.losses.get_regularization_loss())
    losses = tf.convert_to_tensor(losses)
    # Weighted losses
    if loss_weights is not None:
        losses *= loss_weights
    return losses
```
Thank you!
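For concreteness, here is a minimal, self-contained sketch of the idea being discussed (not the current DeepXDE API), assuming the `tf.compat.v1` backend and three stand-in loss terms: if the weights live in a non-trainable `tf.Variable` that the graph reads, an assign op can later zero them out without recompiling.

```python
import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

# Stand-in for the per-term losses built inside the `losses` function above
losses = tf.constant([0.5, 1.2, 0.3])

# A non-trainable variable instead of a Python list baked into the graph
loss_weights = tf.Variable([1.0, 1.0, 1.0], trainable=False)
weighted_losses = losses * loss_weights

# An assign op lets a callback change the weights without recompiling
new_weights = tf.placeholder(tf.float32, shape=[3])
assign_weights = tf.assign(loss_weights, new_weights)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(weighted_losses))  # [0.5 1.2 0.3]
    sess.run(assign_weights, feed_dict={new_weights: np.zeros(3, np.float32)})
    print(sess.run(weighted_losses))  # [0. 0. 0.]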
How do you plan to update `loss_weights`?
Implementing this in TensorFlow is tricky, as it uses a static graph. It should be much easier to implement in PyTorch, where you can directly change the `loss_weights`.
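To illustrate the contrast: in PyTorch the weighted sum is rebuilt on every forward pass, so mutating a plain Python list of weights between iterations is enough. A minimal sketch with two stand-in loss terms:

```python
import torch

# Stand-ins for two loss terms produced by the forward pass
losses = [torch.tensor(0.5), torch.tensor(1.2)]
loss_weights = [1.0, 1.0]

# The weighted sum is recomputed each iteration, so there is no
# frozen graph; changing the list takes effect immediately.
for epoch in range(2):
    total = sum(w * l for w, l in zip(loss_weights, losses))
    print(epoch, total.item())
    loss_weights[1] = 0.0  # takes effect at the next iteration
```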
Thank you for your feedback. I would really prefer to implement this adaptive loss callback in `tf.compat.v1`. I think I'll start with a simple two-term loss (and one weighting parameter).
Hi, if we define the weights this way, we next have to define the `total_loss` appropriately.
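One way to read that remark, under the same assumptions as the sketch above: the scalar handed to the optimizer should be built from the weight variable, so that reassigning the weights changes the objective without a rebuild.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

losses = tf.constant([0.5, 1.2, 0.3])  # stand-in per-term losses
loss_weights = tf.Variable([1.0, 1.0, 1.0], trainable=False)

# The scalar the optimizer minimizes; it reads the variable, so an
# assign op on the weights changes the objective in place.
total_loss = tf.reduce_sum(loss_weights * losses)
```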
Hi @pescap @lululxvi @haison19952013, I have recently been working on the adaptive weights, and here's what worked for me (I think):

```python
import deepxde as dde
from deepxde.backend import tf, Variable


class SoftAdapt(dde.callbacks.Callback):
    """Use adaptive loss balancing.

    Args:
        beta: If beta > 0, then SoftAdapt will pay more attention to the worst
            performing loss component. If beta < 0, then SoftAdapt will assign
            higher weights to the better performing components. beta == 0 is
            the trivial case and all loss components will have coefficient 1.
        epsilon: Parameter to prevent overflows.
    """

    def __init__(self, beta=0.1, epsilon=1e-8, losshistory=None, lr=0.01):
        super().__init__()
        self.beta = beta  # not used yet in this draft
        self.epsilon = epsilon  # not used yet in this draft
        self.losshistory = losshistory
        self.lr = lr

    def on_epoch_end(self):
        loss_weights = self.losshistory.loss_weights
        weight_bc = loss_weights[3]
        weight_data = loss_weights[-1]
        current_loss = self.losshistory.loss_train[-1]
        # Loss layout assumed here: 3 PDE terms, 8 BC terms, 3 data terms
        bc_avg = current_loss[3:11].mean()
        data_avg = current_loss[11:].mean()
        # Exponential moving average of the ratio between the largest PDE
        # loss and the average BC/data losses
        weight_bc = (1 - self.lr) * weight_bc + self.lr * current_loss[0:3].max() / bc_avg
        weight_data = (1 - self.lr) * weight_data + self.lr * current_loss[0:3].max() / data_avg
        loss_weights = [1] * 3 + [weight_bc] * 8 + [weight_data] * 3
        self.losshistory.set_loss_weights(loss_weights)
        print(self.losshistory.loss_weights, "loss_weights")
```
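A hypothetical way to attach this callback, assuming an already-built `dde.Model` named `model` whose 14 loss terms are ordered as the code above hard-codes (3 PDE, 8 BC, 3 data):

```python
# Hypothetical usage; `model` is an existing dde.Model with 14 loss terms
model.compile("adam", lr=1e-3, loss_weights=[1.0] * 14)
softadapt = SoftAdapt(beta=0.1, epsilon=1e-8, losshistory=model.losshistory, lr=0.01)
model.train(iterations=10000, callbacks=[softadapt])
```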
Which backend have you tested this with?
I've used `tensorflow.compat.v1` for this code.
I don't think the code would change the weights during training. For TF v1, if you don't recompile, it always uses the original computational graph.
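A workaround consistent with this remark would be to recompile between training stages, so the updated weights enter a fresh graph. A hypothetical outline, where `compute_new_weights` is a user-supplied function (e.g., the moving-average rule from the callback above):

```python
# Hypothetical staged training: rebuild the graph with updated weights
weights = [1.0] * 14
for stage in range(5):
    model.compile("adam", lr=1e-3, loss_weights=weights)
    model.train(iterations=1000)
    # user-defined update rule based on the recorded loss history
    weights = compute_new_weights(model.losshistory)
```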
Work in progress!