From f871ac1af3ee042e4ca11f274106a2e8b1c2cef5 Mon Sep 17 00:00:00 2001 From: mmdanziger Date: Wed, 25 Oct 2023 10:29:24 +0300 Subject: [PATCH] Add **kwargs to covariate_imbalance_count_error (#64) * Add **kwargs to covariate_imbalance_count_error Other error functions use **kwargs to maintain flexibility about calling API, this function was missing it. * Add test for non-existing kwargs in `count_imbalance` metric Signed-off-by: Ehud-Karavani * Bump CodeClimate Github Action version to v5.0.0 Signed-off-by: Ehud-Karavani --------- Signed-off-by: Ehud-Karavani Co-authored-by: ehudkr Co-authored-by: Ehud-Karavani --- .github/workflows/build.yml | 2 +- causallib/metrics/weight_metrics.py | 2 +- causallib/tests/test_metrics.py | 7 +++++++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6723a59..cb0c64b 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -41,7 +41,7 @@ jobs: # pytest tests.py --doctest-modules --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html - name: Publish to CodeClimate - uses: paambaati/codeclimate-action@v3.2.0 + uses: paambaati/codeclimate-action@v5.0.0 env: CC_TEST_REPORTER_ID: ${{ secrets.CODECLIMATE_REPORTER_ID }} diff --git a/causallib/metrics/weight_metrics.py b/causallib/metrics/weight_metrics.py index e151bce..a190427 100644 --- a/causallib/metrics/weight_metrics.py +++ b/causallib/metrics/weight_metrics.py @@ -122,7 +122,7 @@ def covariate_balancing_error(X, a, sample_weight, agg=max, **kwargs): def covariate_imbalance_count_error( - X, a, sample_weight, threshold=0.1, fraction=True + X, a, sample_weight, threshold=0.1, fraction=True, **kwargs ) -> float: asmds = calculate_covariate_balance(X, a, sample_weight, metric="abs_smd") weighted_asmds = asmds["weighted"] diff --git a/causallib/tests/test_metrics.py b/causallib/tests/test_metrics.py index 2a57418..b26051c 100644 --- a/causallib/tests/test_metrics.py +++ b/causallib/tests/test_metrics.py @@ -298,6 +298,13 @@ def test_covariate_imbalance_count(self): ) self.assertEqual(score, 1/2) + with self.subTest("Doesn't fail on unrelated kwargs"): + covariate_imbalance_count_error( + self.data["X"], self.data["a"], self.data["w"], + nonexistingkwarg=1, + ) + self.assertTrue(True) + class TestOutcomeMetrics(unittest.TestCase): def test_balanced_residuals(self):