2
2
import time
3
3
from .model import estimate_propensity , train
4
4
from .data import dataloaders
5
- from .utils import compute_cross_entropy , compute_distance_correlation , build_synthetic_outcomes , biased_treatment_effect , sigmoid
5
+ from .utils import compute_cross_entropy , compute_distance_correlation , build_synthetic_outcomes , biased_treatment_effect , sigmoid , sig_round
6
6
import numpy as np
7
7
import sklearn .preprocessing
8
8
import pandas as pd
@@ -17,29 +17,32 @@ def benchmark_table(variance, times, print_md=True, print_latex=False):
17
17
# Sometimes NaNs from one of the CATENet methods
18
18
# Remove NaNs from variance[method]
19
19
variance [method ] = [x for x in variance [method ] if not np .isnan (x )]
20
- mean = round (np .mean (variance [method ]))
21
- median = round (np .median (variance [method ]))
22
- upper = round (np .percentile (variance [method ], 75 ))
23
- lower = round (np .percentile (variance [method ], 25 ))
24
- times_mean = round (np .mean (times [method ]))
25
- table .append ([method , mean , lower , median , upper , times_mean ])
20
+ row = [method ]
21
+ mean = np .mean (variance [method ])
22
+ median = np .median (variance [method ])
23
+ upper = np .percentile (variance [method ], 75 )
24
+ lower = np .percentile (variance [method ], 25 )
25
+ times_mean = np .mean (times [method ])
26
+ to_add = [mean , lower , median , upper , times_mean ]
27
+ row += [sig_round (x ) for x in to_add ]
28
+ table .append (row )
26
29
27
30
if print_md :
28
31
print (tabulate (table , headers = ['Method' , 'Mean' , '1st Quartile' , '2nd Quartile' , '3rd Quartile' , 'Time (s)' ], tablefmt = "github" ))
29
32
30
33
cols = []
31
- for i in range (len (table [0 ])- 1 ):
32
- vals = [row [i + 1 ] for row in table ]
34
+ for i in range (1 , len (table [0 ])):
35
+ vals = [row [i ] for row in table ]
33
36
cols += [sorted (vals )]
34
37
for row in table :
35
38
print_row = [row [0 ]]
36
39
for idx in range (1 , len (row )):
37
40
if row [idx ] == cols [idx - 1 ][0 ]:
38
- print_row .append (r'\textbf{' + row [idx ]+ '}' )
41
+ print_row .append (r'\textbf{' + str ( row [idx ]) + '}' )
39
42
elif row [idx ] == cols [idx - 1 ][1 ]:
40
- print_row .append (r'\textit{\textbf{' + row [idx ]+ '}}' )
43
+ print_row .append (r'\textit{\textbf{' + str ( row [idx ]) + '}}' )
41
44
elif row [idx ] == cols [idx - 1 ][2 ]:
42
- print_row .append (r'\underline{\textbf{' + row [idx ]+ '}}' )
45
+ print_row .append (r'\underline{\textbf{' + str ( row [idx ]) + '}}' )
43
46
else :
44
47
print_row .append (row [idx ])
45
48
if print_latex :
@@ -158,13 +161,15 @@ def compute_estimates(methods, dataset, num_runs=10, train_fn=train, folder='',
158
161
159
162
output , times = {}, {}
160
163
with open (filename , 'r' ) as f :
161
- saved = eval (f .readline ())
162
- for method in saved :
163
- if method not in output :
164
- output [method ] = []
165
- times [method ] = []
166
- output [method ] += [saved [method ][0 ]]
167
- times [method ] += [saved [method ][1 ]]
164
+ for line in f :
165
+ line = line .replace ('Array(' , '' ).replace (', dtype=float32)' , '' )
166
+ saved = eval (line )
167
+ for method in saved :
168
+ if method not in output :
169
+ output [method ] = []
170
+ times [method ] = []
171
+ output [method ] += [float (saved [method ][0 ])]
172
+ times [method ] += [saved [method ][1 ]]
168
173
169
174
return output , times
170
175
@@ -192,14 +197,16 @@ def compute_variance_by_n(methods, dataset, ns, num_runs=10, train_fn=train, fol
192
197
193
198
output = {}
194
199
with open (filename , 'r' ) as f :
195
- saved = eval (f .readline ())
196
- for method in saved :
197
- if method not in output :
198
- output [method ] = {}
199
- n = saved ['n' ]
200
- if n not in output [method ]:
201
- output [method ][n ] = []
202
- output [method ][n ] += [saved [method ]]
200
+ for line in f :
201
+ line = line .replace ('Array(' , '' ).replace (', dtype=float32)' , '' )
202
+ saved = eval (line )
203
+ for method in saved :
204
+ if method not in output :
205
+ output [method ] = {}
206
+ n = saved ['n' ]
207
+ if n not in output [method ]:
208
+ output [method ][n ] = []
209
+ output [method ][n ] += [float (saved [method ])]
203
210
204
211
return output
205
212
@@ -236,14 +243,16 @@ def compute_variance_by_entropy(methods, dataset, noise_levels=[0, .2, .3, .4, .
236
243
237
244
output = {}
238
245
with open (filename , 'r' ) as f :
239
- saved = eval (f .readline ())
240
- for method in saved :
241
- if method not in output :
242
- output [method ] = {}
243
- cross_entropy = saved ['cross_entropy' ]
244
- if cross_entropy not in output [method ]:
245
- output [method ][cross_entropy ] = []
246
- output [method ][cross_entropy ] += [saved [method ]]
246
+ for line in f :
247
+ line = line .replace ('Array(' , '' ).replace (', dtype=float32)' , '' )
248
+ saved = eval (line )
249
+ for method in saved :
250
+ if method not in output :
251
+ output [method ] = {}
252
+ cross_entropy = saved ['cross_entropy' ]
253
+ if cross_entropy not in output [method ]:
254
+ output [method ][cross_entropy ] = []
255
+ output [method ][cross_entropy ] += [float (saved [method ])]
247
256
248
257
return output
249
258
@@ -291,13 +300,15 @@ def compute_variance_by_correlation(methods, dataset, alphas=[0, .15, .2, .25, .
291
300
292
301
output = {}
293
302
with open (filename , 'r' ) as f :
294
- saved = eval (f .readline ())
295
- for method in saved :
296
- if method not in output :
297
- output [method ] = {}
298
- correlation = saved ['correlation' ]
299
- if correlation not in output [method ]:
300
- output [method ][correlation ] = []
301
- output [method ][correlation ] += [saved [method ]]
303
+ for line in f :
304
+ line = line .replace ('Array(' , '' ).replace (', dtype=float32)' , '' )
305
+ saved = eval (line )
306
+ for method in saved :
307
+ if method not in output :
308
+ output [method ] = {}
309
+ correlation = saved ['correlation' ]
310
+ if correlation not in output [method ]:
311
+ output [method ][correlation ] = []
312
+ output [method ][correlation ] += [float (saved [method ])]
302
313
303
314
return output
0 commit comments