
CompileLoss truncates multiple loss inputs #19855

Open · Darkdragon84 opened this issue Jun 14, 2024 · 3 comments · May be fixed by #19879
Darkdragon84 commented Jun 14, 2024

Environment

System: Apple M2 MacOS Sonoma 14.4.1
python: 3.11.9
tensorflow: 2.16.1
keras: 3.3.3

Issue

On the first call of model.compute_loss, CompileLoss.call is triggered, which fails if y_true is e.g. a tuple (y1, y2) and the loss function expects y_true: tuple[tf.Tensor, tf.Tensor], because only the first element y1 gets passed to the loss function.

Details

Let's say I have a loss function with the signature

def my_loss(y_true: tuple[tf.Tensor, tf.Tensor], y_pred: tf.Tensor): ...

i.e. the ground truth consists of two tensors (I can't stack them; their sizes are incompatible) and the prediction from my model is a single tensor.

Let's also say I have a model that takes two inputs x1 and x2 and produces one output y_pred.

Due to the nature of my problem, the input is automatically the ground truth, i.e. y_true=(x1, x2).

So I would like to execute

# [...] model and data creation
model.compile(loss=my_loss)
y_pred = model((x1, x2))
loss = model.compute_loss(y=(x1, x2), y_pred=y_pred)

where the last call causes an exception in my_loss, because only x1 gets passed as y_true to my_loss.

I have pinpointed the problem to the call to y_true = self._flatten_y(y_true) in CompileLoss.call (keras/src/trainers/compile_utils.py) and to the subsequent zip iteration over the flattened values.

The problem seems to be that the zip iteration expects y_true and y_pred to be iterables of the same length as self.flat_losses (which in my case is just 1). Now y_pred = self._flatten_y(y_pred) wraps the single tensor y_pred into a one-element list, which is correct. However, y_true = self._flatten_y(y_true) converts y_true from a 2-element tuple into a 2-element list, where it should become a nested list containing a single 2-element entry, i.e. [[y1, y2]].

Consequently, the zip iteration takes only y1 from y_true in its single iteration (since all the other iterables have length one) and passes it as the y_true argument to my_loss.
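
To make the mechanism concrete, here is a minimal sketch of the truncation (illustrative only, assuming _flatten_y behaves like a generic tree-flatten; this is not the actual Keras code):

y_true = ["y1", "y2"]   # the (y1, y2) tuple, flattened into 2 elements
y_pred = ["pred"]       # the single output, wrapped into a 1-element list
flat_losses = ["loss"]  # just one compiled loss

# zip stops at its shortest iterable, so "y2" is silently dropped:
for y_t, y_p, loss in zip(y_true, y_pred, flat_losses):
    print(y_t, y_p, loss)  # prints: y1 pred loss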

I imagine this behavior comes from the fact that when there are multiple losses (each taking a single tensor as y_true and y_pred), the correct way of calling compute_loss is to pass a sequence for y_true and a sequence for y_pred, one element per loss function, as sketched below.
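
For reference, a minimal sketch of that multi-loss convention, using a hypothetical two-output model (the losses "mse" and "mae" and all shapes here are made up for illustration):

import numpy as np
import keras

inp = keras.Input((4,))
out_a = keras.layers.Dense(2)(inp)
out_b = keras.layers.Dense(3)(inp)
model = keras.Model(inputs=inp, outputs=[out_a, out_b])
model.compile(loss=["mse", "mae"])  # one loss per output

x = np.random.uniform(size=(10, 4)).astype("float32")
pred_a, pred_b = model(x)
y_a = np.zeros((10, 2), dtype="float32")
y_b = np.zeros((10, 3), dtype="float32")

# Here the flattening lines up correctly: y_a pairs with "mse" on the
# first output, y_b with "mae" on the second.
loss = model.compute_loss(y=(y_a, y_b), y_pred=(pred_a, pred_b))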

Isn't there a way to reconcile both cases, i.e. structured inputs to a single loss function while still supporting multiple loss functions? All the information is there: how many loss functions there are and what their signatures are. For starters, it is not even checked that all elements in the zip iteration have the same length...
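
In the meantime, a possible workaround sketch, assuming one can subclass the model: override compute_loss to bypass the CompileLoss machinery and hand the structured y directly to the loss (note this skips whatever else the compiled path would do, e.g. loss weighting):

class MyModel(keras.Model):
    def compute_loss(self, x=None, y=None, y_pred=None, sample_weight=None, **kwargs):
        # y is still the raw (x1, x2) tuple here, untouched by _flatten_y
        return my_loss(y, y_pred)

# built functionally, e.g. MyModel(inputs=[input1, input2], outputs=out)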

Reproducible Example

Here is a minimal reproducible example:

import numpy as np
import tensorflow as tf
import tensorflow.keras as keras


def my_loss(y_true: tuple[tf.Tensor, tf.Tensor], y_pred: tf.Tensor):
    y1, y2 = y_true
    pred_sum = keras.ops.sum(y_pred)
    return keras.ops.abs(keras.ops.sum(y1) - pred_sum) + keras.ops.abs(keras.ops.sum(y2) - pred_sum)


def main():
    input1 = keras.Input((2,))
    input2 = keras.Input((3, 6))
    x1 = keras.ops.expand_dims(keras.layers.Dense(10)(input1), 1)
    x2 = keras.layers.Dense(10)(input2)
    x = keras.ops.sum(x1 + x2, axis=1)
    out = keras.layers.Dense(8)(x)
    model = keras.Model(inputs=[input1, input2], outputs=out)
    model.compile(loss=my_loss)

    x1 = tf.random.uniform((10, 2))
    x2 = tf.random.uniform((10, 3, 6))
    y_pred = model((x1, x2))

    loss1 = my_loss((x1, x2), y_pred)
    loss2 = model.compute_loss(y=(x1, x2), y_pred=y_pred)

    print(loss1)
    print(loss2)


if __name__ == "__main__":
    main()

which produces the following error

/Users/valentin/miniconda3/envs/dif/bin/python /Users/valentin/parity/python/diffusion/src/discrete_diffusion/loss_problem.py
Traceback (most recent call last):
  File "/Users/valentin/parity/python/diffusion/src/discrete_diffusion/loss_problem.py", line 34, in <module>
    main()
  File "/Users/valentin/parity/python/diffusion/src/discrete_diffusion/loss_problem.py", line 27, in main
    loss2 = model.compute_loss(y=(x1, x2), y_pred=y_pred)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valentin/miniconda3/envs/dif/lib/python3.11/site-packages/keras/src/trainers/trainer.py", line 316, in compute_loss
    loss = self._compile_loss(y, y_pred, sample_weight)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valentin/miniconda3/envs/dif/lib/python3.11/site-packages/keras/src/trainers/compile_utils.py", line 609, in __call__
    return self.call(y_true, y_pred, sample_weight)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valentin/miniconda3/envs/dif/lib/python3.11/site-packages/keras/src/trainers/compile_utils.py", line 645, in call
    loss(y_t, y_p, sample_weight), dtype=backend.floatx()
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valentin/miniconda3/envs/dif/lib/python3.11/site-packages/keras/src/losses/loss.py", line 43, in __call__
    losses = self.call(y_true, y_pred)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valentin/miniconda3/envs/dif/lib/python3.11/site-packages/keras/src/losses/losses.py", line 22, in call
    return self.fn(y_true, y_pred, **self._fn_kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/valentin/parity/python/diffusion/src/discrete_diffusion/loss_problem.py", line 7, in my_loss
    y1, y2 = y_true
    ^^^^^^

Many thanks for all the great effort on keras, I greatly appreciate all the awesome features!

Darkdragon84 (Author) commented:
The same problem persists if you use dict inputs btw.

def my_loss(y_true: dict[str, tf.Tensor], y_pred: tf.Tensor):
    y1 = y_true["x1"]
    y2 = y_true["x2"]
    pred_sum = keras.ops.sum(y_pred)
    return keras.ops.abs(keras.ops.sum(y1) - pred_sum) + keras.ops.abs(keras.ops.sum(y2) - pred_sum)
...
model.compute_loss(y={"x1": x1, "x2": x2}, y_pred=y_pred)
...

There it's even worse, as the dict gets converted to a (sorted) list by y_true = self._flatten_y(y_true). So if you coded your loss expecting a dict input, it now receives a list, which will again lead to problems.
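
A rough sketch of what happens in the dict case (illustrative only; flattening keeps just the values, in key-sorted order):

y_true = {"x2": "tensor2", "x1": "tensor1"}
flat = [y_true[k] for k in sorted(y_true)]
print(flat)  # ['tensor1', 'tensor2'] -- the keys are gone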

james77777778 (Contributor) commented:
Hi @Darkdragon84

Happy to hear your thoughts on #19879

With that PR and a few small changes, your code works:

import keras


class MyLoss(keras.Loss):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def call(self, y_true, y_pred):
        y1, y2 = y_true
        pred_sum = keras.ops.sum(y_pred)
        return keras.ops.abs(keras.ops.sum(y1) - pred_sum) + keras.ops.abs(
            keras.ops.sum(y2) - pred_sum
        )


def main():
    input1 = keras.Input((2,))
    input2 = keras.Input((3, 6))
    x1 = keras.ops.expand_dims(keras.layers.Dense(10)(input1), 1)
    x2 = keras.layers.Dense(10)(input2)
    x = keras.ops.sum(x1 + x2, axis=1)
    out = keras.layers.Dense(8)(x)
    model = keras.Model(inputs=[input1, input2], outputs=out)
    my_loss = MyLoss()
    my_loss.set_specs([input1, input2], out)  # <-- newly introduced feature
    model.compile(loss=my_loss)

    x1 = keras.random.uniform((10, 2))
    x2 = keras.random.uniform((10, 3, 6))
    y_pred = model((x1, x2))

    loss1 = my_loss((x1, x2), y_pred)
    print(loss1)
    loss2 = model.compute_loss(y=(x1, x2), y_pred=y_pred)
    print(loss2)


if __name__ == "__main__":
    main()

Output:

tf.Tensor(68.81332, shape=(), dtype=float32)
tf.Tensor(68.81332, shape=(), dtype=float32)

Darkdragon84 (Author) commented:
Hi @james77777778

Wow, fantastic, that PR looks great. The changes are way above my understanding of the inner workings of Keras, but I'll be sure to have a look! Thanks a lot for the quick response in the form of a fix PR 👏
