Questions about the computation of the survival layer in MCAT_Surv #8

Open
huangmozhilv opened this issue Sep 20, 2022 · 1 comment

Comments


huangmozhilv commented Sep 20, 2022

It seems that the computation of the survival layer in MCAT_Surv (link) is wrong: logits = self.classifier(h).unsqueeze(0) should be logits = self.classifier(h). With the current version, suppose batch_size=6 and n_classes=4. Then logits has size (1, 6, 4), hazards has size (1, 6, 4), and Y_hat has size (1, 1, 4), because dim=1 is now the batch dimension rather than the class dimension. This Y_hat clearly does not contain a prediction for each of the 6 samples in the batch. Moreover, S becomes the cumulative product of the survival probabilities (i.e. 1 - hazards) along the batch dimension. What would that mean? This S has size (1, 6, 4), so len(S) in CoxSurvLoss (link) is 1, not the batch size as expected.
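A minimal PyTorch sketch that reproduces the shapes described above. It assumes the survival layer computes Y_hat = torch.topk(logits, 1, dim=1)[1], hazards = torch.sigmoid(logits) and S = torch.cumprod(1 - hazards, dim=1), which is consistent with the shapes in the report. The names classifier, hidden_dim and survival_layer are hypothetical stand-ins, not the MCAT code itself.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes and a stand-in for MCAT_Surv's classifier head
batch_size, hidden_dim, n_classes = 6, 256, 4
classifier = nn.Linear(hidden_dim, n_classes)


def survival_layer(h, unsqueeze):
    """Assumed survival head; unsqueeze=True mimics the line questioned in the issue."""
    logits = classifier(h)
    if unsqueeze:
        logits = logits.unsqueeze(0)
    Y_hat = torch.topk(logits, 1, dim=1)[1]  # index of the max over dim=1
    hazards = torch.sigmoid(logits)          # per-bin hazard probabilities
    S = torch.cumprod(1 - hazards, dim=1)    # cumulative survival over dim=1
    return logits, hazards, Y_hat, S


with torch.no_grad():
    h_batch = torch.randn(batch_size, hidden_dim)
    # With unsqueeze(0): dim=1 is the batch dimension, not the class dimension
    _, hazards_old, Y_hat_old, S_old = survival_layer(h_batch, unsqueeze=True)
    # Without unsqueeze(0): dim=1 is the class (time-bin) dimension for every sample
    _, hazards_new, Y_hat_new, S_new = survival_layer(h_batch, unsqueeze=False)
    # A single unbatched feature vector, which the unsqueeze(0) version appears to expect
    _, hazards_one, Y_hat_one, S_one = survival_layer(torch.randn(hidden_dim), unsqueeze=True)

print(hazards_old.shape, Y_hat_old.shape, S_old.shape, len(S_old))
print(hazards_new.shape, Y_hat_new.shape, S_new.shape, len(S_new))
print(hazards_one.shape, Y_hat_one.shape, S_one.shape, len(S_one))
```

The first printed line shows the (1, 6, 4) and (1, 1, 4) shapes and len(S) == 1 from the report. The second gives one Y_hat and one survival curve per sample. The third shows that unsqueeze(0) produces the expected (1, 4) shape only when h is a single unbatched vector.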

Finally, could you provide a reference for the equations you used to implement this Cox loss?

huangmozhilv changed the title from "Typo in computing Y_hat?" to "Questions about the computation of the survival layer in MCAT_Surv" on Sep 21, 2022
@huangmozhilv (Author)

I figured out the reason: the loss you used is not the same as the original Cox loss; it was adapted for use with batch_size=1.
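For comparison, here is a minimal sketch of the standard batched Cox negative partial log-likelihood (Breslow handling of ties), assuming one scalar log-risk score per sample. The function cox_partial_nll and its arguments are hypothetical names, and this is not the CoxSurvLoss from the repository. It also suggests why a batch_size=1 setting needs a different loss: with a single sample, the risk set contains only that sample, so the standard loss is identically zero.

```python
import torch


def cox_partial_nll(risk, time, event):
    """Negative Cox partial log-likelihood (Breslow ties), averaged over observed events.

    risk:  (N,) predicted log-risk scores (higher means an earlier expected event)
    time:  (N,) observed survival or censoring times
    event: (N,) 1.0 if the event was observed, 0.0 if censored
    """
    n = risk.shape[0]
    # at_risk[i, j] is True when sample j is still at risk at sample i's time
    at_risk = time.unsqueeze(0) >= time.unsqueeze(1)
    # Exclude samples outside the risk set from the log-sum-exp over j
    scores = risk.unsqueeze(0).expand(n, n).masked_fill(~at_risk, float("-inf"))
    log_denominator = torch.logsumexp(scores, dim=1)
    # Only samples with an observed event contribute a partial-likelihood term
    log_lik = (risk - log_denominator) * event
    return -log_lik.sum() / event.sum().clamp(min=1.0)


# Two uncensored samples with equal risk: the loss is log(2) / 2
loss_two = cox_partial_nll(torch.tensor([0.0, 0.0]), torch.tensor([1.0, 2.0]), torch.tensor([1.0, 1.0]))
# A single sample: its risk set is only itself, so the loss is 0 for any score
loss_one = cox_partial_nll(torch.tensor([1.3]), torch.tensor([5.0]), torch.tensor([1.0]))
print(loss_two.item(), loss_one.item())
```

The >= comparison puts tied samples in each other's risk sets, which is the Breslow approximation; other tie-handling schemes (e.g. Efron) would change the denominator.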
