This repository contains the Python implementation of the paper "Balancing Interference and Correlation in Spatial Experimental Designs: A Causal Graph Cut Approach" (ICML 2025).
This paper focuses on the design of spatial experiments to optimize the amount of information derived from the experimental data and enhance the accuracy of the resulting causal effect estimator. We propose a surrogate function for the mean squared error of the estimator, which facilitates the use of classical graph cut algorithms to learn the optimal design. Our proposal offers three key advances: (1) it accommodates moderate to large spatial interference effects; (2) it adapts to different spatial covariance functions; (3) it is computationally efficient.
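The toy sketch below makes the trade-off concrete by scoring a partition of spatial units with two quantities. The first is the adjacency weight of edges cut between clusters, a heuristic proxy for interference bias, because a cut edge joins units that may receive different treatments. The second is the exact noise variance of a Horvitz-Thompson ATE estimator, assuming each cluster is independently treated with probability `p` and the outcome noise (covariance `V`) is mean-zero and independent of the assignment. This illustrates those assumptions only. It is not the paper's surrogate function, and the function names and the chain example are hypothetical.

```python
import numpy as np


def cut_weight(W, labels):
    """Total adjacency weight on edges whose endpoints lie in different clusters."""
    labels = np.asarray(labels)
    different = labels[:, None] != labels[None, :]
    # W is symmetric, so every cut edge is counted twice; halve the sum.
    return float(W[different].sum() / 2.0)


def ht_noise_variance(V, labels, p=0.5):
    """Noise variance of the Horvitz-Thompson ATE estimator when each cluster
    is independently treated with probability p (hypothetical helper)."""
    labels = np.asarray(labels)
    n = labels.size
    same = labels[:, None] == labels[None, :]
    # Unit weight s_i = Z_i/p - (1-Z_i)/(1-p) has mean 0 and
    # E[s_i s_j] = 1/(p(1-p)) if i, j share a cluster, else 0.
    return float(V[same].sum() / (p * (1 - p)) / n**2)


# Toy example: 6 units on a line, neighbours connected, exponential covariance.
n = 6
W = np.zeros((n, n))
idx = np.arange(n - 1)
W[idx, idx + 1] = 1.0
W[idx + 1, idx] = 1.0
coords = np.arange(n)
V = np.exp(-np.abs(np.subtract.outer(coords, coords)))

partitions = {
    "singletons": [0, 1, 2, 3, 4, 5],
    "halves": [0, 0, 0, 1, 1, 1],
    "one block": [0, 0, 0, 0, 0, 0],
}
for name, labels in partitions.items():
    print(f"{name:>10}: cut={cut_weight(W, labels):.1f}, "
          f"variance={ht_noise_variance(V, labels):.3f}")
```

On the chain, singletons cut every edge but keep only the diagonal of `V`, while a single block cuts nothing but sums all of `V`. The design problem described above is about striking a balance between these extremes.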
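- Change your working directory to this main folder and run `setup.sh` to configure the environment and install all requirements.
- Run the following scripts to reproduce the figures in the paper:
  - `./figure3b.sh` --> reproduce Figure 3(b)
  - `./figure6.sh` --> reproduce Figure 6
  - `./figure7.sh` --> reproduce Figure 7
  - `./figure8\&9.sh` --> reproduce Figure 8 and Figure 9 (escape or quote the `&`; otherwise the shell treats it as the background operator)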
Warm-up. If you have access to the spatial covariance, you can employ the oracle causal graph cut as follows:
```python
### 1. configure the doubly robust estimator
from sklearn.ensemble import RandomForestRegressor
from semi_sp_design import SemiEstimator
model = RandomForestRegressor(random_state=0, n_estimators=10)
semi_est = SemiEstimator(n_splits=2, model=model)

### 2. get spatial clusters by the (oracle) causal graph cut
from SemiGraphCut import multi_graph_cut
# your_env is a placeholder for your own spatial environment object
W = your_env.get_adj_matrix()  # adjacency matrix of the spatial units
V = your_env.get_cov_matrix()  # (oracle) spatial covariance matrix
spat_cluster, _ = multi_graph_cut(W=W, V=V)

### 3. get the ATE estimate under the cluster design
from semi_sp_design import ClusterDesign  # import path assumed; adjust if ClusterDesign is defined elsewhere
c_design = ClusterDesign(p=0.5, W=W, cluster=spat_cluster)
semi_est.update_design(c_design)
hat_tau_C = semi_est.estimate(your_env, N=100)[0]  # first returned value is the ATE estimate
print("Estimator:", hat_tau_C)
```
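Conceptually, a cluster design with `p=0.5` randomizes treatment at the cluster level rather than the unit level. The sketch below assumes independent Bernoulli(`p`) assignment per cluster and a plain Horvitz-Thompson estimate of the ATE. The function names are hypothetical, and whether `ClusterDesign` and `SemiEstimator` work exactly this way is an assumption (`SemiEstimator` is described as doubly robust, which this sketch does not implement).

```python
import numpy as np


def cluster_bernoulli_assignment(labels, p=0.5, rng=None):
    """Treat each cluster independently with probability p; every unit
    inherits the treatment of its cluster (hypothetical helper)."""
    rng = np.random.default_rng() if rng is None else rng
    labels = np.asarray(labels)
    clusters = np.unique(labels)
    treated_clusters = clusters[rng.random(clusters.size) < p]
    return np.isin(labels, treated_clusters).astype(int)


def ht_ate(y, z, p=0.5):
    """Horvitz-Thompson estimate of the average treatment effect."""
    y = np.asarray(y, dtype=float)
    z = np.asarray(z)
    return float(np.mean(z * y / p - (1 - z) * y / (1 - p)))


# Toy simulation: 100 units in 10 clusters, constant true effect tau = 2.
rng = np.random.default_rng(0)
labels = np.repeat(np.arange(10), 10)
tau = 2.0
estimates = []
for _ in range(2000):
    z = cluster_bernoulli_assignment(labels, p=0.5, rng=rng)
    y = 1.0 + tau * z + rng.normal(scale=0.5, size=labels.size)
    estimates.append(ht_ate(y, z, p=0.5))
print("mean of HT estimates:", np.mean(estimates))  # should be close to tau = 2
```

Because all units in a cluster share one coin flip, the estimator stays unbiased under this assignment, but its variance depends on how the units are grouped.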
More realistic cases. When the spatial covariance is unknown, it can be estimated iteratively with the causal graph cut as follows:
```python
### configure the doubly robust estimator as before
from sklearn.ensemble import RandomForestRegressor
from semi_sp_design import SemiEstimator
model = RandomForestRegressor(random_state=0, n_estimators=10)
semi_est = SemiEstimator(n_splits=2, model=model)

### perform the causal graph cut algorithm
from SemiGraphCut import online_graph_cut
online_graph_cut(your_env, semi_est)  # your_env: placeholder for your own spatial environment
hat_tau_C = semi_est.estimate(your_env, N=100)[0]  # first returned value is the ATE estimate
print("Estimator:", hat_tau_C)
```
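When the covariance is unknown, a natural plug-in is the sample covariance of outcome residuals observed over repeated periods. The sketch below assumes residuals are stored as an N x n array (periods x units). The helper name is hypothetical, and whether `online_graph_cut` uses this exact estimator inside its iterations is an assumption.

```python
import numpy as np


def empirical_spatial_cov(residuals):
    """Sample covariance across units from an (N periods x n units) residual array."""
    residuals = np.asarray(residuals, dtype=float)
    # rowvar=False: each column is a unit (variable), each row a period (observation).
    return np.cov(residuals, rowvar=False)


# Toy check: draw residuals from a known exponential covariance on 3 units.
rng = np.random.default_rng(0)
coords = np.arange(3)
V_true = np.exp(-np.abs(np.subtract.outer(coords, coords)))
residuals = rng.multivariate_normal(np.zeros(3), V_true, size=5000)
V_hat = empirical_spatial_cov(residuals)
print(np.round(V_hat, 2))
```

With enough periods, `V_hat` approaches `V_true`. In an iterative scheme, each refreshed estimate could be passed to the graph cut in place of the oracle `V`.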
Please cite our paper *Balancing Interference and Correlation in Spatial Experimental Designs: A Causal Graph Cut Approach* (ICML 2025):
```bibtex
@inproceedings{zhu2025balancing,
  title={Balancing Interference and Correlation in Spatial Experimental Designs: A Causal Graph Cut Approach},
  author={Zhu, Jin and Li, Jingyi and Zhou, Hongyi and Lin, Yinan and Lin, Zhenhua and Shi, Chengchun},
  booktitle={International Conference on Machine Learning},
  year={2025},
  organization={PMLR}
}
```