
Estimator return types #264

Merged: 40 commits merged into main from interaction-terms on Feb 28, 2024.
Changes shown below are from 34 of the 40 commits.

Commits
0eefd4b
Patsy interaction treatment terms can now be read
jmafoster1 Feb 7, 2024
c255ce0
refactor estimate_coefficient to only return pd.Series
christopher-wild Feb 13, 2024
4736391
Adapt unit tests to access series values for coefficients
christopher-wild Feb 13, 2024
1eac357
Fetch factors from Patsy to check types
christopher-wild Feb 13, 2024
f616091
Return float rather than Series for ci_high and ci_low
christopher-wild Feb 13, 2024
1e8ad89
Handle float and series confidence intervals
christopher-wild Feb 14, 2024
abe8599
Handle correct exception
christopher-wild Feb 14, 2024
767e5b4
More flexible handling due to multiple float types returned by estimators
christopher-wild Feb 14, 2024
2724d94
Handle series values
christopher-wild Feb 14, 2024
da87421
Linting
christopher-wild Feb 14, 2024
aca8bf9
Merge branch 'main' into interaction-terms
christopher-wild Feb 14, 2024
ea8c273
refactor all estimate_* return types to be pd.Series for LinearRegres…
christopher-wild Feb 14, 2024
c777503
Update return typings
christopher-wild Feb 14, 2024
fd7f79d
Refactor other estimator classes to return pd.Series
christopher-wild Feb 14, 2024
8f9499d
Extract bool from series
christopher-wild Feb 14, 2024
3bcefc7
Remove gen expression so elements can be indexed
christopher-wild Feb 14, 2024
e43bf38
All effects now expect pd.Series for the test values
christopher-wild Feb 14, 2024
ace2612
Update all unit tests to work with pd.Series refactor
christopher-wild Feb 14, 2024
c942537
Merge remote-tracking branch 'origin/interaction-terms' into interact…
christopher-wild Feb 14, 2024
18883d8
example_poisson_process.py now works with pd.Series refactor
christopher-wild Feb 14, 2024
8590dc9
Update typing
christopher-wild Feb 23, 2024
8e33b25
Update tests to use pd.Series for confidence intervals
christopher-wild Feb 23, 2024
9641319
Dictionary assertions use list CIs
christopher-wild Feb 23, 2024
35bae1f
SomeEffect and NoneEffect apply methods now work with pd.Series
christopher-wild Feb 23, 2024
6a5987a
_get_confidence_intervals method returns pd.Series
christopher-wild Feb 23, 2024
d0322ed
Merge branch 'main' into interaction-terms
christopher-wild Feb 23, 2024
33e7e53
Remove unnecessary unpacking of value
christopher-wild Feb 23, 2024
123d4db
tests represent the logic of returning Series better
christopher-wild Feb 23, 2024
17b8692
Pylint suggestions
christopher-wild Feb 23, 2024
4e06f34
Update surrogate code for new series return vals
rsomers1998 Feb 26, 2024
f66f854
Merge branch 'interaction-terms' of https://github.com/CITCOM-project…
rsomers1998 Feb 26, 2024
d742f74
Remove unused import
christopher-wild Feb 27, 2024
425329b
Update LR91 examples
christopher-wild Feb 27, 2024
c026f6a
Update example_beta.py
christopher-wild Feb 27, 2024
eb6bca6
Raise exception for Positive and Negative effect if multiple values p…
christopher-wild Feb 27, 2024
5265e9f
Fix typo in check for value length
christopher-wild Feb 27, 2024
fb287ec
Add test for catching multiple value exception
christopher-wild Feb 27, 2024
844849a
Use pandas inbuilt assert_series_equal test instead of casting everyt…
christopher-wild Feb 27, 2024
6029edb
Add limitation of single test_value to Effect docstrings
christopher-wild Feb 28, 2024
b8ad419
Ensure only single ate values are provided in surrogate_models
christopher-wild Feb 28, 2024
2 changes: 1 addition & 1 deletion causal_testing/surrogate/surrogate_search_algorithms.py

@@ -46,7 +46,7 @@ def fitness_function(ga, solution, idx):  # pylint: disable=unused-argument

            ate = surrogate.estimate_ate_calculated(adjustment_dict)

-            return contradiction_function(ate)
+            return contradiction_function(ate[0])
Review comment (Contributor):
I'm a little worried about the scalability of this, as it will only work for float values. If you have categorical data types, this will reduce to a binary treatment between the first and second values in alphabetical order. If you really want to support only float values here, it would be good to have an explicit check, so the user gets what they expect to get.
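As a minimal sketch of the kind of explicit check suggested here (the helper and its name are hypothetical, not part of this PR):

import numbers

import pandas as pd


def safe_scalar_ate(ate: pd.Series) -> float:
    # Guard against categorical treatments silently collapsing to the
    # first dummy-coded level when only ate[0] is used downstream.
    if len(ate) != 1:
        raise ValueError(f"Expected a single ATE value, got {len(ate)}: {ate.to_list()}")
    value = ate.iloc[0]
    if not isinstance(value, numbers.Real):
        raise TypeError(f"Expected a numeric ATE, got {type(value).__name__}")
    return float(value)

Under that assumption, fitness_function could call contradiction_function(safe_scalar_ate(ate)) and fail loudly instead of silently mishandling categorical output.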


        gene_types, gene_space = self.create_gene_types(surrogate, specification)
47 changes: 21 additions & 26 deletions causal_testing/testing/causal_test_outcome.py
@@ -27,14 +27,13 @@ class SomeEffect(CausalTestOutcome):
    """An extension of TestOutcome representing that the expected causal effect should not be zero."""

    def apply(self, res: CausalTestResult) -> bool:
-        if res.test_value.type == "ate":
-            return (0 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 0)
-        if res.test_value.type == "coefficient":
-            ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
-            ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
-            return any(0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(ci_low, ci_high))
        if res.test_value.type == "risk_ratio":
-            return (1 < res.ci_low() < res.ci_high()) or (res.ci_low() < res.ci_high() < 1)
+            return any(
+                1 < ci_low < ci_high or ci_low < ci_high < 1 for ci_low, ci_high in zip(res.ci_low(), res.ci_high()))
+        if res.test_value.type in ('coefficient', 'ate'):
+            return any(
+                0 < ci_low < ci_high or ci_low < ci_high < 0 for ci_low, ci_high in zip(res.ci_low(), res.ci_high()))

        raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
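A rough usage sketch of the refactored branch (values assumed, not taken from the test suite): SomeEffect now iterates pairwise over Series-valued bounds and holds if any interval lies strictly on one side of zero.

import pandas as pd

# Hypothetical per-coefficient 95% confidence intervals.
ci_low = pd.Series([0.2, -0.5])
ci_high = pd.Series([0.9, -0.1])

some_effect = any(
    0 < low < high or low < high < 0 for low, high in zip(ci_low, ci_high)
)
print(some_effect)  # True: both intervals exclude zero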


@@ -51,23 +50,20 @@ def __init__(self, atol: float = 1e-10, ctol: float = 0.05):
        self.ctol = ctol

    def apply(self, res: CausalTestResult) -> bool:
-        if res.test_value.type == "ate":
-            return (res.ci_low() < 0 < res.ci_high()) or (abs(res.test_value.value) < self.atol)
-        if res.test_value.type == "coefficient":
-            ci_low = res.ci_low() if isinstance(res.ci_low(), Iterable) else [res.ci_low()]
-            ci_high = res.ci_high() if isinstance(res.ci_high(), Iterable) else [res.ci_high()]
+        if res.test_value.type == "risk_ratio":
+            return any(ci_low < 1 < ci_high or np.isclose(value, 1.0, atol=self.atol) for ci_low, ci_high, value in
+                       zip(res.ci_low(), res.ci_high(), res.test_value.value))
+        if res.test_value.type in ('coefficient', 'ate'):
            value = res.test_value.value if isinstance(res.ci_high(), Iterable) else [res.test_value.value]

            return (
-                sum(
-                    not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
-                    for ci_low, ci_high, v in zip(ci_low, ci_high, value)
-                )
-                / len(value)
-                < self.ctol
+                sum(
+                    not ((ci_low < 0 < ci_high) or abs(v) < self.atol)
+                    for ci_low, ci_high, v in zip(res.ci_low(), res.ci_high(), value)
+                )
+                / len(value)
+                < self.ctol
            )
-        if res.test_value.type == "risk_ratio":
-            return (res.ci_low() < 1 < res.ci_high()) or np.isclose(res.test_value.value, 1.0, atol=self.atol)

        raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
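To make the ctol logic concrete, a small worked example with assumed numbers: the effect counts as "none" when the fraction of coefficients whose interval excludes zero (and whose value is not within atol of zero) stays below ctol.

import pandas as pd

atol, ctol = 1e-10, 0.05

# Hypothetical coefficients: the first interval straddles zero, the second does not.
ci_low = pd.Series([-0.1, 0.3])
ci_high = pd.Series([0.1, 0.8])
values = pd.Series([0.02, 0.5])

violations = sum(
    not ((low < 0 < high) or abs(v) < atol)
    for low, high, v in zip(ci_low, ci_high, values)
)
# 1 of 2 coefficients shows a real effect; 0.5 is not below ctol, so NoneEffect fails.
print(violations / len(values) < ctol)  # False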


@@ -99,10 +95,9 @@ def apply(self, res: CausalTestResult) -> bool:
        if res.ci_valid() and not super().apply(res):
            return False
        if res.test_value.type in {"ate", "coefficient"}:
-            return bool(res.test_value.value > 0)
+            return bool(res.test_value.value[0] > 0)
        if res.test_value.type == "risk_ratio":
-            return bool(res.test_value.value > 1)
-        # Dead code but necessary for pylint
+            return bool(res.test_value.value[0] > 1)
        raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")


@@ -113,8 +108,8 @@ def apply(self, res: CausalTestResult) -> bool:
        if res.ci_valid() and not super().apply(res):
            return False
        if res.test_value.type in {"ate", "coefficient"}:
-            return bool(res.test_value.value < 0)
+            return bool(res.test_value.value[0] < 0)
        if res.test_value.type == "risk_ratio":
-            return bool(res.test_value.value < 1)
+            return bool(res.test_value.value[0] < 1)
        # Dead code but necessary for pylint
        raise ValueError(f"Test Value type {res.test_value.type} is not valid for this TestOutcome")
6 changes: 5 additions & 1 deletion causal_testing/testing/causal_test_result.py
@@ -27,7 +27,7 @@ def __init__(
        self,
        estimator: Estimator,
        test_value: TestValue,
-        confidence_intervals: [float, float] = None,
+        confidence_intervals: [pd.Series, pd.Series] = None,
        effect_modifier_configuration: {Variable: Any} = None,
        adequacy=None,
    ):
@@ -99,12 +99,16 @@ def to_dict(self, json=False):
    def ci_low(self):
        """Return the lower bracket of the confidence intervals."""
        if self.confidence_intervals:
+            if isinstance(self.confidence_intervals[0], pd.Series):
+                return self.confidence_intervals[0].to_list()
            return self.confidence_intervals[0]
        return None

    def ci_high(self):
        """Return the higher bracket of the confidence intervals."""
        if self.confidence_intervals:
+            if isinstance(self.confidence_intervals[1], pd.Series):
+                return self.confidence_intervals[1].to_list()
            return self.confidence_intervals[1]
        return None
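A quick sketch (values assumed) of what the new branch gives callers:

import pandas as pd

# Hypothetical result with Series-valued bounds.
confidence_intervals = [pd.Series([0.1, 0.2]), pd.Series([0.5, 0.9])]

# Mirrors ci_low()/ci_high(): Series bounds come back as plain lists.
ci_low = confidence_intervals[0].to_list()   # [0.1, 0.2]
ci_high = confidence_intervals[1].to_list()  # [0.5, 0.9]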

61 changes: 29 additions & 32 deletions causal_testing/testing/estimators.py
@@ -11,7 +11,7 @@
import statsmodels.formula.api as smf
from econml.dml import CausalForestDML
from patsy import dmatrix  # pylint: disable = no-name-in-module
-
+from patsy import ModelDesc
from sklearn.ensemble import GradientBoostingRegressor
from statsmodels.regression.linear_model import RegressionResultsWrapper
from statsmodels.tools.sm_exceptions import PerfectSeparationError
@@ -343,30 +343,28 @@ def add_modelling_assumptions(self):
            "do not need to be linear."
        )

-    def estimate_coefficient(self) -> float:
+    def estimate_coefficient(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
        """Estimate the unit average treatment effect of the treatment on the outcome. That is, the change in outcome
        caused by a unit change in treatment.

        :return: The unit average treatment effect and the 95% Wald confidence intervals.
        """
        model = self._run_linear_regression()
        newline = "\n"
-        if str(self.df.dtypes[self.treatment]) == "object":
+        treatment = [self.treatment]
+        patsy_md = ModelDesc.from_formula(self.treatment)
+        if any((self.df.dtypes[factor.name()] == 'object' for factor in patsy_md.rhs_termlist[1].factors)):
            design_info = dmatrix(self.formula.split("~")[1], self.df).design_info
            treatment = design_info.column_names[design_info.term_name_slices[self.treatment]]
-        else:
-            treatment = [self.treatment]
        assert set(treatment).issubset(
            model.params.index.tolist()
        ), f"{treatment} not in\n{' ' + str(model.params.index).replace(newline, newline + ' ')}"
        unit_effect = model.params[treatment]  # Unit effect is the coefficient of the treatment
        [ci_low, ci_high] = self._get_confidence_intervals(model, treatment)
-        if str(self.df.dtypes[self.treatment]) != "object":
-            unit_effect = unit_effect[0]
-            ci_low = ci_low[0]
-            ci_high = ci_high[0]
        return unit_effect, [ci_low, ci_high]
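The categorical-treatment detection above leans on patsy's formula parsing; a standalone sketch of the same check (toy DataFrame assumed):

import pandas as pd
from patsy import ModelDesc

df = pd.DataFrame({"dose": ["low", "high", "low"], "y": [1.0, 2.0, 1.5]})
treatment = "dose"

# Parse the treatment as a one-sided formula and inspect its factors;
# rhs_termlist[0] is the implicit intercept, [1] holds the treatment term.
patsy_md = ModelDesc.from_formula(treatment)
is_categorical = any(
    df.dtypes[factor.name()] == "object"
    for factor in patsy_md.rhs_termlist[1].factors
)
print(is_categorical)  # True: "dose" is dtype object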

-    def estimate_ate(self) -> tuple[float, list[float, float], float]:
+    def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
        """Estimate the average treatment effect of the treatment on the outcome. That is, the change in outcome caused
        by changing the treatment variable from the control value to the treatment value.

@@ -384,8 +382,9 @@ def estimate_ate(self) -> tuple[float, list[float, float], float]:

        # Perform a t-test to compare the predicted outcome of the control and treated individual (ATE)
        t_test_results = model.t_test(individuals.loc["treated"] - individuals.loc["control"])
-        ate = t_test_results.effect[0]
+        ate = pd.Series(t_test_results.effect[0])
        confidence_intervals = list(t_test_results.conf_int(alpha=self.alpha).flatten())
+        confidence_intervals = [pd.Series(interval) for interval in confidence_intervals]
        return ate, confidence_intervals

    def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd.Series, pd.Series]:

@@ -414,7 +413,7 @@ def estimate_control_treatment(self, adjustment_config: dict = None) -> tuple[pd.Series, pd.Series]:

        return y.iloc[1], y.iloc[0]

-    def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
+    def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
        """Estimate the risk_ratio effect of the treatment on the outcome. That is, the change in outcome caused
        by changing the treatment variable from the control value to the treatment value.

@@ -423,12 +422,11 @@ def estimate_risk_ratio(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
        if adjustment_config is None:
            adjustment_config = {}
        control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
-        ci_low = treatment_outcome["mean_ci_lower"] / control_outcome["mean_ci_upper"]
-        ci_high = treatment_outcome["mean_ci_upper"] / control_outcome["mean_ci_lower"]
-
-        return (treatment_outcome["mean"] / control_outcome["mean"]), [ci_low, ci_high]
+        ci_low = pd.Series(treatment_outcome["mean_ci_lower"] / control_outcome["mean_ci_upper"])
+        ci_high = pd.Series(treatment_outcome["mean_ci_upper"] / control_outcome["mean_ci_lower"])
+        return pd.Series(treatment_outcome["mean"] / control_outcome["mean"]), [ci_low, ci_high]

-    def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
+    def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
        """Estimate the ate effect of the treatment on the outcome. That is, the change in outcome caused
        by changing the treatment variable from the control value to the treatment value. Here, we actually
        calculate the expected outcomes under control and treatment and divide one by the other. This

@@ -439,10 +437,9 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> tuple[float, list[float, float]]:
        if adjustment_config is None:
            adjustment_config = {}
        control_outcome, treatment_outcome = self.estimate_control_treatment(adjustment_config=adjustment_config)
-        ci_low = treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"]
-        ci_high = treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"]
-
-        return (treatment_outcome["mean"] - control_outcome["mean"]), [ci_low, ci_high]
+        ci_low = pd.Series(treatment_outcome["mean_ci_lower"] - control_outcome["mean_ci_upper"])
+        ci_high = pd.Series(treatment_outcome["mean_ci_upper"] - control_outcome["mean_ci_lower"])
+        return pd.Series(treatment_outcome["mean"] - control_outcome["mean"]), [ci_low, ci_high]

    def _run_linear_regression(self) -> RegressionResultsWrapper:
        """Run linear regression of the treatment and adjustment set against the outcome and return the model.

@@ -456,8 +453,8 @@ def _run_linear_regression(self) -> RegressionResultsWrapper:
    def _get_confidence_intervals(self, model, treatment):
        confidence_intervals = model.conf_int(alpha=self.alpha, cols=None)
        ci_low, ci_high = (
-            confidence_intervals[0].loc[treatment],
-            confidence_intervals[1].loc[treatment],
+            pd.Series(confidence_intervals[0].loc[treatment]),
+            pd.Series(confidence_intervals[1].loc[treatment]),
        )
        return [ci_low, ci_high]
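For context on the shapes being wrapped here, a small standalone sketch (toy data, assumed column names) of what statsmodels conf_int() returns:

import pandas as pd
import statsmodels.formula.api as smf

df = pd.DataFrame({"x": [0.0, 1.0, 2.0, 3.0, 4.0], "y": [0.1, 1.9, 4.2, 5.8, 8.1]})
model = smf.ols("y ~ x", df).fit()

# conf_int() gives a DataFrame indexed by term, columns 0 (lower) and 1 (upper);
# .loc[["x"]] keeps a Series per bound, which pd.Series(...) then normalises.
ci = model.conf_int(alpha=0.05)
ci_low, ci_high = pd.Series(ci[0].loc[["x"]]), pd.Series(ci[1].loc[["x"]])
print(ci_low.to_list(), ci_high.to_list())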

@@ -495,7 +492,7 @@ def __init__(
        terms = [treatment] + sorted(list(adjustment_set)) + sorted(list(effect_modifiers))
        self.formula = f"{outcome} ~ cr({'+'.join(terms)}, df={basis})"

-    def estimate_ate_calculated(self, adjustment_config: dict = None) -> float:
+    def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
        model = self._run_linear_regression()

        x = {"Intercept": 1, self.treatment: self.treatment_value}

@@ -511,7 +508,7 @@ def estimate_ate_calculated(self, adjustment_config: dict = None) -> pd.Series:
        x[self.treatment] = self.control_value
        control = model.predict(x).iloc[0]

-        return treatment - control
+        return pd.Series(treatment - control)


class InstrumentalVariableEstimator(Estimator):
@@ -567,7 +564,7 @@ def add_modelling_assumptions(self):
            """
        )

-    def estimate_iv_coefficient(self, df):
+    def estimate_iv_coefficient(self, df) -> float:
        """
        Estimate the linear regression coefficient of the treatment on the
        outcome.

@@ -581,7 +578,7 @@ def estimate_iv_coefficient(self, df):
        # Estimate the coefficient of I on X by cancelling
        return ab / a

-    def estimate_coefficient(self, bootstrap_size=100):
+    def estimate_coefficient(self, bootstrap_size=100) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
        """
        Estimate the unit ate (i.e. coefficient) of the treatment on the
        outcome.

@@ -590,10 +587,10 @@ def estimate_coefficient(self, bootstrap_size=100):
            [self.estimate_iv_coefficient(self.df.sample(len(self.df), replace=True)) for _ in range(bootstrap_size)]
        )
        bound = ceil((bootstrap_size * self.alpha) / 2)
-        ci_low = bootstraps[bound]
-        ci_high = bootstraps[bootstrap_size - bound]
+        ci_low = pd.Series(bootstraps[bound])
+        ci_high = pd.Series(bootstraps[bootstrap_size - bound])

-        return self.estimate_iv_coefficient(self.df), (ci_low, ci_high)
+        return pd.Series(self.estimate_iv_coefficient(self.df)), [ci_low, ci_high]
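The bound arithmetic above trims alpha/2 of the sorted bootstrap estimates from each tail; a self-contained sketch with assumed data:

from math import ceil

import numpy as np

alpha, bootstrap_size = 0.05, 100
rng = np.random.default_rng(0)

# Stand-in for the sorted per-resample IV coefficient estimates.
bootstraps = np.sort(rng.normal(loc=2.0, scale=0.3, size=bootstrap_size))

bound = ceil((bootstrap_size * alpha) / 2)  # 3 of 100 trimmed from each tail
ci_low, ci_high = bootstraps[bound], bootstraps[bootstrap_size - bound]
print(float(ci_low), float(ci_high))  # roughly the empirical 2.5th/97.5th percentiles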


class CausalForestEstimator(Estimator):
Expand All @@ -610,7 +607,7 @@ def add_modelling_assumptions(self):
"""
self.modelling_assumptions.append("Non-parametric estimator: no restrictions imposed on the data.")

def estimate_ate(self) -> float:
def estimate_ate(self) -> tuple[pd.Series, list[pd.Series, pd.Series]]:
"""Estimate the average treatment effect.

:return ate, confidence_intervals: The average treatment effect and 95% confidence intervals.
Expand Down Expand Up @@ -638,9 +635,9 @@ def estimate_ate(self) -> float:
model.fit(outcome_df, treatment_df, X=effect_modifier_df, W=confounders_df)

# Obtain the ATE and 95% confidence intervals
ate = model.ate(effect_modifier_df, T0=self.control_value, T1=self.treatment_value)
ate = pd.Series(model.ate(effect_modifier_df, T0=self.control_value, T1=self.treatment_value))
ate_interval = model.ate_interval(effect_modifier_df, T0=self.control_value, T1=self.treatment_value)
ci_low, ci_high = ate_interval[0], ate_interval[1]
ci_low, ci_high = pd.Series(ate_interval[0]), pd.Series(ate_interval[1])
return ate, [ci_low, ci_high]

def estimate_cates(self) -> pd.DataFrame:
13 changes: 6 additions & 7 deletions examples/covasim_/doubling_beta/example_beta.py
@@ -276,8 +276,8 @@ def setup(observational_data):

def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=None, col=None):
    # Get the CATE as a percentage for association and causation
-    ate = results_dict["causation"]["ate"]
-    association_ate = results_dict["association"]["ate"]
+    ate = results_dict["causation"]["ate"][0]
+    association_ate = results_dict["association"]["ate"][0]

    causation_df = results_dict["causation"]["df"]
    association_df = results_dict["association"]["df"]

@@ -288,11 +288,10 @@ def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=None, col=None):
    # Get 95% confidence intervals for association and causation
    ate_cis = results_dict["causation"]["cis"]
    association_ate_cis = results_dict["association"]["cis"]
-    percentage_causal_ate_cis = [round(((ci / causation_df["cum_infections"].mean()) * 100), 3) for ci in ate_cis]
+    percentage_causal_ate_cis = [round(((ci[0] / causation_df["cum_infections"].mean()) * 100), 3) for ci in ate_cis]
    percentage_association_ate_cis = [
-        round(((ci / association_df["cum_infections"].mean()) * 100), 3) for ci in association_ate_cis
+        round(((ci[0] / association_df["cum_infections"].mean()) * 100), 3) for ci in association_ate_cis
    ]
-
    # Convert confidence intervals to errors for plotting
    percentage_causal_errs = [
        percentage_ate - percentage_causal_ate_cis[0],

@@ -314,9 +313,9 @@ def plot_doubling_beta_CATEs(results_dict, title, figure=None, axes=None, row=None, col=None):
    if "counterfactual" in results_dict.keys():
        cf_ate = results_dict["counterfactual"]["ate"]
        cf_df = results_dict["counterfactual"]["df"]
-        percentage_cf_ate = round((cf_ate / cf_df["cum_infections"].mean()) * 100, 3)
+        percentage_cf_ate = round((cf_ate[0] / cf_df["cum_infections"].mean()) * 100, 3)
        cf_ate_cis = results_dict["counterfactual"]["cis"]
-        percentage_cf_cis = [round(((ci / cf_df["cum_infections"].mean()) * 100), 3) for ci in cf_ate_cis]
+        percentage_cf_cis = [round(((ci[0] / cf_df["cum_infections"].mean()) * 100), 3) for ci in cf_ate_cis]
        percentage_cf_errs = [percentage_cf_ate - percentage_cf_cis[0], percentage_cf_cis[1] - percentage_cf_ate]
        xs = [0.5, 1.5, 2.5]
        ys = [association_percentage_ate, percentage_ate, percentage_cf_ate]
4 changes: 2 additions & 2 deletions examples/lr91/example_max_conductances.py
@@ -164,8 +164,8 @@ def plot_ates_with_cis(results_dict: dict, xs: list, save: bool = False, show: bool = False):
        before_underscore, after_underscore = treatment.split("_")
        after_underscore_braces = f"{{{after_underscore}}}"
        latex_compatible_treatment_str = rf"${before_underscore}_{after_underscore_braces}$"
-        cis_low = [c[0] for c in cis]
-        cis_high = [c[1] for c in cis]
+        cis_low = [c[0][0] for c in cis]
+        cis_high = [c[1][0] for c in cis]
        axes.fill_between(
            xs, cis_low, cis_high, alpha=0.2, color=input_colors[treatment], label=latex_compatible_treatment_str
        )
4 changes: 2 additions & 2 deletions examples/lr91/example_max_conductances_test_suite.py
@@ -166,8 +166,8 @@ def plot_ates_with_cis(results_dict: dict, xs: list, save: bool = False, show=False):
        before_underscore, after_underscore = treatment.split("_")
        after_underscore_braces = f"{{{after_underscore}}}"
        latex_compatible_treatment_str = rf"${before_underscore}_{after_underscore_braces}$"
-        cis_low = [c[0] for c in cis]
-        cis_high = [c[1] for c in cis]
+        cis_low = [c[0][0] for c in cis]
+        cis_high = [c[1][0] for c in cis]
        axes.fill_between(
            xs, cis_low, cis_high, alpha=0.2, color=input_colors[treatment], label=latex_compatible_treatment_str
        )
4 changes: 2 additions & 2 deletions examples/poisson-line-process/example_poisson_process.py
@@ -198,8 +198,8 @@ def test_poisson_width_num_shapes(save=False):
                "treatment": treatment_value,
                "intensity": i,
                "ate": causal_test_result.test_value.value,
-                "ci_low": min(causal_test_result.confidence_intervals),
-                "ci_high": max(causal_test_result.confidence_intervals),
+                "ci_low": causal_test_result.confidence_intervals[0][0],
+                "ci_high": causal_test_result.confidence_intervals[1][0],
            }
            width_num_shapes_results.append(results)
        width_num_shapes_results = pd.DataFrame(width_num_shapes_results)