Skip to content

Commit 48f0ae1

Browse files
committed
Adjust AdversarialBalancing to numpy>=2
replace `row_stack` with `vstack` and `np.NaN` with `np.nan`. Signed-off-by: Ehud-Karavani <[email protected]>
1 parent 3551319 commit 48f0ae1

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

causallib/contrib/adversarial_balancing/adversarial_balancing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def _run(self, X, A, w_init=None, is_train=True, use_stabilized=None, **select_k
137137
if not np.all(unique_treatments == np.arange(n_treatments)):
138138
raise AssertionError("Treatment values in `a` must be indexed 0, 1, 2, ...")
139139
self.iterative_models_ = np.empty((n_treatments, self.iterations), dtype=object)
140-
self.iterative_normalizing_consts_ = np.full((n_treatments, self.iterations), np.NaN)
140+
self.iterative_normalizing_consts_ = np.full((n_treatments, self.iterations), np.nan)
141141

142142
self.discriminator_loss_ = np.zeros((n_treatments, self.iterations))
143143
self.treatments_frequency_ = _compute_treatments_frequency(A)
@@ -147,7 +147,7 @@ def _run(self, X, A, w_init=None, is_train=True, use_stabilized=None, **select_k
147147
# population ("source population"),
148148
# and the samples with label -1 are the population under treatment a ("target population").
149149
# Labels 1 and -1 (rather than 0) are used because of the later exponential loss function
150-
X_augm = np.row_stack((X, X[A == a])) # create the augmented dataset
150+
X_augm = np.vstack((X, X[A == a])) # create the augmented dataset
151151
y = np.ones((X_augm.shape[0]))
152152
y[X.shape[0]:] *= -1 # subpopulation of current treatment (a) has y== -1
153153
target_pop_mask = y == -1

causallib/contrib/tests/test_adversarial_balancing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class TestAdversarialBalancing(unittest.TestCase):
1616
def create_identical_treatment_groups_data(n=100):
1717
np.random.seed(42)
1818
X = np.random.rand(n, 3)
19-
X = np.row_stack((X, X)) # Duplicate identical samples
19+
X = np.vstack((X, X)) # Duplicate identical samples
2020
a = np.array([1] * n + [0] * n) # Give duplicated samples different treatment assignment
2121
X, a = pd.DataFrame(X), pd.Series(a)
2222
return X, a

0 commit comments

Comments
 (0)