Pytorch implementation #1
base: dev/pytorch
Conversation
- With PyTorch's StudentT and the TF/manual torch implementations, I'm able to get probabilities > 1 with data generated from a contrived linear regression that the model can solve. Not sure why.
First of all, thank you for contributing the new PyTorch implementation! We really appreciate this initiative on your part. I would like to take some time to review your updates before merging into master / PyPI. I can provide some comments on the PR and (if it's okay with you) commit some suggestions to your branch before merging. In the meantime, I believe your code will serve as a great starting point for others who would like to try the method in PyTorch.
Hi! I'd actually suggest you make a PyTorch dev branch and I'll re-PR into that. I always wonder why that isn't an option on GitHub, since it's more logical. That way I can also revert all of the neurips2020 import refactoring and leave all of the original code untouched. Please do give suggestions - I'm especially not happy at the moment with how the two implementations handle the namespace. I think there's potentially a much cleaner way (for the user) to go about this, and it's definitely not ready to merge into main as is.
Good point, I just created the PyTorch dev branch (dev/pytorch). I agree with the namespace issue. I think one way to cleanly handle this is how Keras used to handle multiple backends: read the backend from an OS environment variable (doc). A similar approach is adopted by pyrender. Alternatively, we could adopt an approach similar to how matplotlib works (it has a base method that allows switching backends).
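For concreteness, a minimal sketch of the Keras-style option (Keras read `KERAS_BACKEND`; the `EDL_BACKEND` variable name here is just a placeholder, not anything this repo defines):

```python
# Hypothetical sketch: choose the backend from an OS variable at import
# time, the way Keras did with KERAS_BACKEND. EDL_BACKEND is an assumed name.
import os

_BACKEND = os.environ.get("EDL_BACKEND", "tf")
if _BACKEND not in ("tf", "torch"):
    raise ValueError(f"Unknown EDL_BACKEND: {_BACKEND!r} (use 'tf' or 'torch')")
```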
This reverts commit 24cc9aa.
This reverts commit f69c59f.
This reverts commit d973b95.
Automatically detects torch and tf availability:
- error when neither is available
- when only one is available, allows only that backend
- when both are available, defaults to the tf backend
- set_backend('tf'|'torch') manipulates the edl.loss and edl.layers namespaces
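A minimal sketch of what this detection logic could look like (the `edl.tf_impl` / `edl.torch_impl` submodule names are placeholders, not the actual layout of the package):

```python
# Sketch of the auto-detection described above; the edl.tf_impl /
# edl.torch_impl submodule names are placeholders, not the real layout.
import importlib
import importlib.util

_HAS_TF = importlib.util.find_spec("tensorflow") is not None
_HAS_TORCH = importlib.util.find_spec("torch") is not None

if not (_HAS_TF or _HAS_TORCH):
    raise ImportError("Neither tensorflow nor torch is installed; "
                      "at least one backend is required.")

def set_backend(name):
    """Rebind the edl.loss / edl.layers namespaces to one backend."""
    global loss, layers
    if (name == "tf" and not _HAS_TF) or (name == "torch" and not _HAS_TORCH):
        raise ValueError(f"backend {name!r} is not installed")
    impl = importlib.import_module(f"edl.{name}_impl")  # placeholder modules
    loss, layers = impl.loss, impl.layers

# At import time the package would default to tf when both are present:
# set_backend("tf" if _HAS_TF else "torch")
```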
I think the next steps would be first validating the PyTorch code and then finding a cleaner way to handle the namespace.
- from https://github.com/dougbrion/pytorch-classification-uncertainty
- MIT license
- Dirichlet dist. with log-loss
- No KL divergence annealing
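For reference, a hedged sketch of that Dirichlet log-loss (the negative log marginal likelihood of the label under a Dirichlet, from Sensoy et al. 2018); the function and argument names here are illustrative, not the linked repo's API:

```python
# Sketch of the Dirichlet log-loss: -log of the label's marginal
# likelihood under Dir(alpha). Names are illustrative only.
import torch

def dirichlet_log_loss(alpha, y_onehot):
    """alpha: (batch, num_classes) Dirichlet parameters (evidence + 1).
    y_onehot: (batch, num_classes) one-hot targets."""
    strength = alpha.sum(dim=-1, keepdim=True)  # Dirichlet strength S
    loss = (y_onehot * (torch.log(strength) - torch.log(alpha))).sum(dim=-1)
    return loss.mean()
```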
Dirichlet UQ for discrete classification is implemented and validated @benjiachong
Hi @Dariusrussellkish, did you solve this problem in the PyTorch implementation? Thank you.
Hi all, great talk and paper.
I did the preliminary work of porting this to PyTorch. There are a few niceties that could still be implemented, like specifying the batch dimension, some customization of the reduction, and the way huggingface/transformers supports both tf and torch implementations without requiring both as dependencies.
Otherwise it's all there for NIG. I didn't implement the Dirichlet_SOS loss since it wasn't clear where it would be used. I'll work on porting the NeurIPS examples, but since that will take a while, I figured it would be useful to give the base torch code for now.
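For reviewers, a rough sketch of what the NIG negative log-likelihood (Eq. 8 of the paper) looks like ported to torch; the function name and signature here are illustrative, not necessarily what this PR exposes:

```python
# Sketch of the NIG NLL from the paper, ported to torch for illustration.
# Name and signature are assumptions; all arguments are tensors.
import math
import torch

def nig_nll(y, gamma, nu, alpha, beta):
    """NLL of y under the Normal-Inverse-Gamma evidential distribution."""
    omega = 2.0 * beta * (1.0 + nu)
    return (0.5 * (math.log(math.pi) - torch.log(nu))
            - alpha * torch.log(omega)
            + (alpha + 0.5) * torch.log(nu * (y - gamma) ** 2 + omega)
            + torch.lgamma(alpha)
            - torch.lgamma(alpha + 0.5))
```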
Of note: I found some numerical instabilities/issues with the Student-t distribution when the model makes very confident regressions. Just to test the torch version, I did SGD on a simple contrived linear regression and found that as the model achieved a strong fit, its probabilities via the Student-t went > 1 (and the NLL went < 0). It obviously doesn't hinder training, but it seems a bit off to have a calculation produce probabilities > 1.
PyTorch has its own implementation of StudentT, which suffers from the same instabilities but is numerically different. I went with directly porting the TF code for numerical consistency.
Something like the snippet below will recreate this instability.
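(A minimal sketch of such a snippet; the exact parameter values are illustrative.)

```python
# A very confident Student-t (tiny scale) assigns the target a density > 1,
# so log_prob > 0 and the NLL goes negative. Parameter values illustrative.
import torch
from torch.distributions import StudentT

y = torch.tensor(0.0)                      # observed target
dist = StudentT(df=torch.tensor(50.0),     # confident fit: large df,
                loc=torch.tensor(0.0),     # prediction on the target,
                scale=torch.tensor(0.01))  # tiny predictive scale

log_prob = dist.log_prob(y)
print(log_prob.item())   # > 0, i.e. "probability" (density) > 1
print(-log_prob.item())  # NLL < 0
```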