Skip to content

Commit

Permalink
Add **kwargs to covariate_imbalance_count_error (#64)
Browse files Browse the repository at this point in the history
* Add **kwargs to covariate_imbalance_count_error

Other error functions use **kwargs to maintain flexibility about calling API, this function was missing it.

* Add test for non-existing kwargs in `count_imbalance` metric

Signed-off-by: Ehud-Karavani <[email protected]>

* Bump CodeClimate Github Action version to v5.0.0

Signed-off-by: Ehud-Karavani <[email protected]>

---------

Signed-off-by: Ehud-Karavani <[email protected]>
Co-authored-by: ehudkr <[email protected]>
Co-authored-by: Ehud-Karavani <[email protected]>
  • Loading branch information
3 people authored Oct 25, 2023
1 parent be8276b commit f871ac1
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
# pytest tests.py --doctest-modules --junitxml=junit/test-results.xml --cov=com --cov-report=xml --cov-report=html
- name: Publish to CodeClimate
uses: paambaati/codeclimate-action@v3.2.0
uses: paambaati/codeclimate-action@v5.0.0
env:
CC_TEST_REPORTER_ID: ${{ secrets.CODECLIMATE_REPORTER_ID }}

2 changes: 1 addition & 1 deletion causallib/metrics/weight_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def covariate_balancing_error(X, a, sample_weight, agg=max, **kwargs):


def covariate_imbalance_count_error(
X, a, sample_weight, threshold=0.1, fraction=True
X, a, sample_weight, threshold=0.1, fraction=True, **kwargs
) -> float:
asmds = calculate_covariate_balance(X, a, sample_weight, metric="abs_smd")
weighted_asmds = asmds["weighted"]
Expand Down
7 changes: 7 additions & 0 deletions causallib/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,13 @@ def test_covariate_imbalance_count(self):
)
self.assertEqual(score, 1/2)

with self.subTest("Doesn't fail on unrelated kwargs"):
covariate_imbalance_count_error(
self.data["X"], self.data["a"], self.data["w"],
nonexistingkwarg=1,
)
self.assertTrue(True)


class TestOutcomeMetrics(unittest.TestCase):
def test_balanced_residuals(self):
Expand Down

0 comments on commit f871ac1

Please sign in to comment.