Test loss #120

Open
raimis opened this issue Aug 29, 2022 · 2 comments
Comments

@raimis
Collaborator

raimis commented Aug 29, 2022

Which function is supposed to compute the test loss? Can we clean it up?

def validation_step(self, batch, batch_idx, *args):
    if len(args) == 0 or (len(args) > 0 and args[0] == 0):
        # validation step
        return self.step(batch, mse_loss, "val")
    # test step
    return self.step(batch, l1_loss, "test")

def test_step(self, batch, batch_idx):
    return self.step(batch, l1_loss, "test")

Ping: @PhilippThoelke @stefdoerr

raimis added the question (Further information is requested) label on Aug 29, 2022
@giadefa
Contributor

giadefa commented Aug 29, 2022 via email

@PhilippThoelke
Collaborator

The loss is computed in step, regardless of whether we are in train, val, or test. The function validation_step is called by pytorch-lightning to evaluate the model on all validation dataloaders, while test_step is only called at the end of training or via trainer.test(...). Since we want to compute the test loss during training, one of the validation dataloaders contains the test set.
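As a minimal sketch (not the repository's actual data module; the class name, dataset attributes, and the evaluate_test_during_training flag are hypothetical), this is how returning two dataloaders from val_dataloader() makes pytorch-lightning pass a dataloader_idx into validation_step, so index 0 is the validation set and index 1 is the test set evaluated during training:

    import pytorch_lightning as pl
    from torch.utils.data import DataLoader

    class ExampleDataModule(pl.LightningDataModule):  # hypothetical name
        def val_dataloader(self):
            # First loader: the real validation set (dataloader_idx == 0).
            loaders = [DataLoader(self.val_dataset, batch_size=32)]
            if self.evaluate_test_during_training:  # hypothetical flag
                # Second loader: the test set, evaluated during training
                # (dataloader_idx == 1 in validation_step).
                loaders.append(DataLoader(self.test_dataset, batch_size=32))
            return loaders

        def test_dataloader(self):
            # Only used by trainer.test(...) after training has finished.
            return DataLoader(self.test_dataset, batch_size=32)

With such a setup, the args[0] == 0 check in validation_step above distinguishes the validation dataloader from the test-set dataloader.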
