Skip to content

Commit b58110b

Browse files
author
R. Teal Witter
committed
table formatting fixes
1 parent 1e82b32 commit b58110b

File tree

4 files changed

+58
-46
lines changed

4 files changed

+58
-46
lines changed

naturalexperiments/benchmark.py

Lines changed: 54 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import time
33
from .model import estimate_propensity, train
44
from .data import dataloaders
5-
from .utils import compute_cross_entropy, compute_distance_correlation, build_synthetic_outcomes, biased_treatment_effect, sigmoid
5+
from .utils import compute_cross_entropy, compute_distance_correlation, build_synthetic_outcomes, biased_treatment_effect, sigmoid, sig_round
66
import numpy as np
77
import sklearn.preprocessing
88
import pandas as pd
@@ -17,29 +17,32 @@ def benchmark_table(variance, times, print_md=True, print_latex=False):
1717
# Sometimes NaNs from one of the CATENet methods
1818
# Remove NaNs from variance[method]
1919
variance[method] = [x for x in variance[method] if not np.isnan(x)]
20-
mean = round(np.mean(variance[method]))
21-
median = round(np.median(variance[method]))
22-
upper = round(np.percentile(variance[method], 75))
23-
lower = round(np.percentile(variance[method], 25))
24-
times_mean = round(np.mean(times[method]))
25-
table.append([method, mean, lower, median, upper, times_mean])
20+
row = [method]
21+
mean = np.mean(variance[method])
22+
median = np.median(variance[method])
23+
upper = np.percentile(variance[method], 75)
24+
lower = np.percentile(variance[method], 25)
25+
times_mean = np.mean(times[method])
26+
to_add = [mean, lower, median, upper, times_mean]
27+
row += [sig_round(x) for x in to_add]
28+
table.append(row)
2629

2730
if print_md:
2831
print(tabulate(table, headers=['Method', 'Mean', '1st Quartile', '2nd Quartile', '3rd Quartile', 'Time (s)'], tablefmt="github"))
2932

3033
cols = []
31-
for i in range(len(table[0])-1):
32-
vals = [row[i+1] for row in table]
34+
for i in range(1,len(table[0])):
35+
vals = [row[i] for row in table]
3336
cols += [sorted(vals)]
3437
for row in table:
3538
print_row = [row[0]]
3639
for idx in range(1, len(row)):
3740
if row[idx] == cols[idx-1][0]:
38-
print_row.append(r'\textbf{'+row[idx]+'}')
41+
print_row.append(r'\textbf{'+str(row[idx])+'}')
3942
elif row[idx] == cols[idx-1][1]:
40-
print_row.append(r'\textit{\textbf{'+row[idx]+'}}')
43+
print_row.append(r'\textit{\textbf{'+str(row[idx])+'}}')
4144
elif row[idx] == cols[idx-1][2]:
42-
print_row.append(r'\underline{\textbf{'+row[idx]+'}}')
45+
print_row.append(r'\underline{\textbf{'+str(row[idx])+'}}')
4346
else:
4447
print_row.append(row[idx])
4548
if print_latex:
@@ -158,13 +161,15 @@ def compute_estimates(methods, dataset, num_runs=10, train_fn=train, folder='',
158161

159162
output, times = {}, {}
160163
with open(filename, 'r') as f:
161-
saved = eval(f.readline())
162-
for method in saved:
163-
if method not in output:
164-
output[method] = []
165-
times[method] = []
166-
output[method] += [saved[method][0]]
167-
times[method] += [saved[method][1]]
164+
for line in f:
165+
line = line.replace('Array(', '').replace(', dtype=float32)', '')
166+
saved = eval(line)
167+
for method in saved:
168+
if method not in output:
169+
output[method] = []
170+
times[method] = []
171+
output[method] += [float(saved[method][0])]
172+
times[method] += [saved[method][1]]
168173

169174
return output, times
170175

@@ -192,14 +197,16 @@ def compute_variance_by_n(methods, dataset, ns, num_runs=10, train_fn=train, fol
192197

193198
output = {}
194199
with open(filename, 'r') as f:
195-
saved = eval(f.readline())
196-
for method in saved:
197-
if method not in output:
198-
output[method] = {}
199-
n = saved['n']
200-
if n not in output[method]:
201-
output[method][n] = []
202-
output[method][n] += [saved[method]]
200+
for line in f:
201+
line = line.replace('Array(', '').replace(', dtype=float32)', '')
202+
saved = eval(line)
203+
for method in saved:
204+
if method not in output:
205+
output[method] = {}
206+
n = saved['n']
207+
if n not in output[method]:
208+
output[method][n] = []
209+
output[method][n] += [float(saved[method])]
203210

204211
return output
205212

@@ -236,14 +243,16 @@ def compute_variance_by_entropy(methods, dataset, noise_levels=[0, .2, .3, .4, .
236243

237244
output = {}
238245
with open(filename, 'r') as f:
239-
saved = eval(f.readline())
240-
for method in saved:
241-
if method not in output:
242-
output[method] = {}
243-
cross_entropy = saved['cross_entropy']
244-
if cross_entropy not in output[method]:
245-
output[method][cross_entropy] = []
246-
output[method][cross_entropy] += [saved[method]]
246+
for line in f:
247+
line = line.replace('Array(', '').replace(', dtype=float32)', '')
248+
saved = eval(line)
249+
for method in saved:
250+
if method not in output:
251+
output[method] = {}
252+
cross_entropy = saved['cross_entropy']
253+
if cross_entropy not in output[method]:
254+
output[method][cross_entropy] = []
255+
output[method][cross_entropy] += [float(saved[method])]
247256

248257
return output
249258

@@ -291,13 +300,15 @@ def compute_variance_by_correlation(methods, dataset, alphas=[0, .15, .2, .25, .
291300

292301
output = {}
293302
with open(filename, 'r') as f:
294-
saved = eval(f.readline())
295-
for method in saved:
296-
if method not in output:
297-
output[method] = {}
298-
correlation = saved['correlation']
299-
if correlation not in output[method]:
300-
output[method][correlation] = []
301-
output[method][correlation] += [saved[method]]
303+
for line in f:
304+
line = line.replace('Array(', '').replace(', dtype=float32)', '')
305+
saved = eval(line)
306+
for method in saved:
307+
if method not in output:
308+
output[method] = {}
309+
correlation = saved['correlation']
310+
if correlation not in output[method]:
311+
output[method][correlation] = []
312+
output[method][correlation] += [float(saved[method])]
302313

303314
return output

naturalexperiments/estimators/catenet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ def get_catenet_estimate(X, y, z, p, train_fn):
1010
t = catenet_models[model_name]()
1111
t.fit(X, y, w)
1212
cate_pred = t.predict(X)
13-
return cate_pred.mean()
13+
return float(cate_pred.mean())
1414
return get_catenet_estimate

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
setuptools.setup(
77
name="naturalexperiments",
8-
version="0.0.9",
8+
version="0.1.2",
99
author="R. Teal Witter",
1010
author_email="[email protected]",
1111
description="Estimators and datasets for treatment effect estimation in natural experiments.",

test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ def test(dataset, method_name):
1717
#for dataset in dataloaders:
1818
test(dataset, method_name)
1919
#compute_estimates(methods, dataset, num_runs=1, folder='output')
20-
#compute_variance(methods, dataset, num_runs=3, folder='output')
20+
#variance, times = compute_variance(methods, dataset, num_runs=0, folder='output')
21+
#benchmark_table(variance, times, print_md=True, print_latex=True)
2122
#compute_variance_by_n(methods, dataset, ns=[1000,3000,4000], num_runs=3, folder='output')
2223
#compute_variance_by_correlation(methods, dataset, num_runs=1, folder='output')
2324
#compute_variance_by_entropy(methods, dataset, num_runs=1, folder='output')

0 commit comments

Comments
 (0)