Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add: prototype implementation of tarp in sbi * fix: wrong use of torch.nn loss functions * fix: wrong use of min values per dim * introduced overconfident / underdispersed samples * fix: wrong generation of toy gaussian data * add simple test to check detection of pathological cases * added biased case * separate quite long test file * prepared test case with trained NPE - does not work yet * first draft of TARP implementation for now * formatting code according to ruff * refactoring all class based methods into free functions * removed print statements from tests * removed TARP class * removed tests for TARP class * renamed tests * added kstest and difference of area under curve - these checks will help practitioners to identify malformed posteriors - added tests for these checks * refactored metrics into sbi.utils.metrics * prefer optional over Union[type,None] * refer to num_ instead of n_ * asserting shapes early on * remove obsolete assert * refactored references check into its own function * ruff reformatting * check method updated, plot method added - reformatted code - added plotting method to make it easier to visualize TARP - rewrote ATC property of ECP plot to indicate into which direction the posterior is shifted or dispersed - updated tests accordingly * removed superflous shape checks as best as possible * refactored run_tarp into function similar to run_sbc * refactored check_references - added docstring - removed spurious variables * pyright fixes for tarp code * pyright fixes to correct parameter name * Apply suggestions from code review * Apply suggestions from code review: line length * Update tests/tarp_test.py make slow test pass. * Apply suggestions from code review * fix last pyright issues * refactoring --------- Co-authored-by: Jan <[email protected]> Co-authored-by: janfb <[email protected]>
- Loading branch information