-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
148 lines (127 loc) · 5.34 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import json
import numpy as np
import os
import pandas as pd
import streamlit as st
import seaborn as sns
import torch
from pages import features_page, data_fairness_page, model_bias_detection_page
def convert_to_tensor(df, columns, type='Float'):
    """Build a Torch tensor from selected DataFrame columns.

    The selected columns are cast to integers before the tensor is created.
    NOTE(review): this truncates fractional values even for the 'Float'
    output - presumably the features are already integer-encoded; confirm.

    :param df: pandas DataFrame holding the data
    :param columns: iterable of column names to extract
    :param type: 'Long' for an int64 tensor, anything else for a tensor of
        the default floating point dtype
    :return: 2-D Torch tensor (rows x selected columns)
    """
    arr = np.array(df[[*columns]]).astype(int)
    if type == 'Long':
        # Build the int64 tensor directly from the integer array. Going
        # through a float32 tensor first silently corrupts integers larger
        # than 2**24.
        return torch.as_tensor(arr, dtype=torch.long)
    return torch.as_tensor(arr, dtype=torch.get_default_dtype())
def predict(model, inputs, threshold=0.5):
    """Run a model in inference mode and threshold its output.

    :param model: Torchscript model (any callable exposing ``eval()``)
    :param inputs: Torch tensor, or tuple of Torch tensors, to be fed to the
        model as positional arguments
    :param threshold: Classification threshold value, default 0.5
    :return pred_proba: Numpy array with the raw model output
    :return y_pred: list of predicted labels (0/1), 1 where the output is
        greater than or equal to the threshold
    """
    # A bare tensor would otherwise be unpacked along its first axis by the
    # ``*inputs`` call below, so wrap it to honour the documented contract.
    if isinstance(inputs, torch.Tensor):
        inputs = (inputs,)

    model.eval()
    with torch.no_grad():
        pred_proba = model(*inputs)

    # convert from tensor to np array
    pred_proba = pred_proba.detach().cpu().numpy()
    y_pred = [1 if proba >= threshold else 0 for proba in pred_proba]
    return pred_proba, y_pred
def read_csv_list(file_list, selected=None):
    """Read several CSV sources into DataFrames keyed by file name.

    :param file_list: iterable of path strings or file-like objects exposing
        a ``name`` attribute
    :param selected: name/path of the source to mark as selected; when falsy,
        the user is asked to choose one through a sidebar select box
    :return df_dict: dict mapping each file's base name (extension removed)
        to its DataFrame
    :return select_key: key of the selected source, or None when no source
        matches ``selected``
    """
    # Check for user uploads
    if not selected:
        selected = st.sidebar.selectbox("Select one dataset for selection of features.",
                                        options=[file.name for file in file_list],
                                        index=0)

    df_dict = {}
    select_key = None
    for source in file_list:
        df = pd.read_csv(source)
        name = source if isinstance(source, str) else source.name
        # splitext removes whatever extension is present instead of blindly
        # chopping four characters, which mangled names not ending in ".csv".
        key = os.path.splitext(os.path.basename(name))[0]
        df_dict[key] = df
        if name == selected:
            select_key = key
    return df_dict, select_key
def run_inference(file_list, df_dict, json_files):  # TODO
    """Load each uploaded model, run it on its test data and collect results.

    For every model file, the JSON file and the DataFrame sharing the same
    base name are looked up. Two columns, ``<y>_prediction`` and
    ``<y>_probability``, are added to that DataFrame in place, where ``<y>``
    is the first entry of the JSON's ``'y'`` list.

    :param file_list: model files (objects with a ``name`` attribute,
        loadable by ``torch.jit.load``)
    :param df_dict: dict of test DataFrames keyed by base file name
    :param json_files: feature-definition JSON files (objects with ``name``)
    :return: dict mapping each model's key to its augmented DataFrame
    """
    predictions = {}
    for model_file in file_list:
        # Key is the model file name without its 3-character suffix
        key = os.path.basename(model_file.name)[:-3]

        # Feature definition whose name (minus 5-character suffix) matches
        matching_json = [candidate for candidate in json_files
                         if candidate.name[:-5] == key]
        feature_dict = json.load(matching_json[0])

        # Test data registered under the same key
        test_df = df_dict[key]
        x1_tensor = convert_to_tensor(test_df, feature_dict.get('x1'), type='Long')
        x2_tensor = convert_to_tensor(test_df, feature_dict.get('x2'))

        # Load the model and score the test data
        loaded_model = torch.jit.load(model_file)
        probabilities, labels = predict(loaded_model, (x1_tensor, x2_tensor))

        target = feature_dict['y'][0]
        test_df[target + '_prediction'] = labels
        test_df[target + '_probability'] = probabilities
        predictions[key] = test_df
    return predictions
def sidebar_handler(label, type_list, eg_dict):
    """Render the sidebar (example listing and uploader) and pick the data to use.

    The example datasets are always loaded and listed. If the user uploads
    files, those are used instead; when more than one file type is accepted,
    inference is also run on the uploads. Any failure during inference falls
    back to the example datasets after showing a warning.

    :param label: text shown on the file uploader
    :param type_list: list of accepted file extensions for the uploader
    :param eg_dict: dict mapping example dataset labels to CSV paths
    :return: tuple of (dict of DataFrames, key of the selected DataFrame)
    """
    # Example Use Case
    eg_labels = list(eg_dict.keys())
    st.sidebar.title('Example: NYC Subway Traffic')
    eg_df_dict, _ = read_csv_list(eg_dict.values(), eg_dict[eg_labels[0]])
    # Re-key the example DataFrames by their display labels
    eg_df_dict_rep_key = dict(zip(eg_labels, eg_df_dict.values()))
    example = ''.join('- **%s**\n' % dataset for dataset in eg_labels)
    st.sidebar.markdown(example)

    # User Upload
    st.sidebar.title('Upload')
    extensions = ', '.join([extension.upper() for extension in type_list])
    file_list = st.sidebar.file_uploader('%s, (%s)' % (label, extensions),
                                         type=type_list,
                                         accept_multiple_files=True)

    # Nothing uploaded: show the example datasets
    if not file_list:
        return eg_df_dict_rep_key, eg_labels[0]

    # Load Files, split by the MIME type reported for each upload
    csv_files = [file for file in file_list if file.type in ['text/csv', 'application/vnd.ms-excel']]
    pt_files = [file for file in file_list if file.type in ['application/octet-stream']]
    json_files = [file for file in file_list if file.type in ['application/json']]
    df_dict, select_key = read_csv_list(csv_files)

    if len(type_list) > 1:
        try:
            # Run Inference
            pred_dict = run_inference(pt_files, df_dict, json_files)
            # Lookup only validates that the selected dataset has predictions;
            # a KeyError here triggers the fallback below.
            _ = pred_dict[select_key]
            return pred_dict, select_key
        except Exception:
            # Deliberately broad best-effort fallback, but no longer a bare
            # ``except`` so KeyboardInterrupt/SystemExit and other
            # BaseException-derived signals propagate.
            st.warning("Please ensure you have uploaded the corresponding model, test dataset and features json files with the same name for each model")
            return eg_df_dict_rep_key, eg_labels[0]
    return df_dict, select_key
# Config (must run before any other Streamlit call in the script)
st.set_page_config(page_title='FairWell',
                   layout='wide',
                   initial_sidebar_state='expanded')

# Sidebar
st.sidebar.title('FairWell')
page = st.sidebar.radio('Navigate',
                        options=['Guide',
                                 'Feature Explorer',
                                 'Data Fairness Assessment',
                                 'Model Bias Detection & Mitigation'],
                        index=0)

# Title
st.title('FairWell')

# Pages
selected_page = page.lower()
if selected_page == 'guide':
    # Read the README inside a context manager so the file handle is closed
    with open('README.md', 'r', encoding='utf8') as readme:
        about = readme.read()
    # Point relative image links at the hosted copies and drop the first
    # 12 characters of the README before rendering.
    about = about.replace('./images/', 'https://raw.githubusercontent.com/FairWell-dev/FairWell/main/images/')[12:]
    st.markdown(about, unsafe_allow_html=True)
elif selected_page == 'feature explorer':
    features_page.render(sidebar_handler)
elif selected_page == 'data fairness assessment':
    data_fairness_page.render(sidebar_handler)
elif selected_page == 'model bias detection & mitigation':
    model_bias_detection_page.render(sidebar_handler)
else:
    st.text('Page ' + page + ' is not implemented.')