-
Notifications
You must be signed in to change notification settings - Fork 0
/
tabnet_nhs_stranded.R
212 lines (161 loc) · 5.92 KB
/
tabnet_nhs_stranded.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
# Install development version of tabnet package.
# NOTE(review): running install_github() on every execution mutates the user
# package library and requires network access — consider moving this to a
# one-off project setup step guarded by requireNamespace().
remotes::install_github("mlverse/tabnet")
library(tidymodels)
library(parameters)
library(skimr)
library(remotes)
library(tidyverse)
library(parallel)
library(doParallel)
library(vip)
library(themis)        # step_upsample() used in the recipe
library(lme4)          # NOTE(review): not used in the visible script
library(BradleyTerry2) # NOTE(review): not used in the visible script
library(finetune)      # tune_race_win_loss()
library(butcher)       # workflow size reduction before saving
library(lobstr)        # obj_size()
library(lubridate)
library(NHSRdatasets)  # stranded_data
library(torch)
library(tabnet)
library(yardstick)
# Seed both the base-R RNG and torch's RNG so the run is reproducible.
set.seed(777)
torch_manual_seed(777)
# Read in data ----
## A "stranded" patient has been in hospital for longer than 7 days;
## these are also called Long Waiters.
strand_pat <- NHSRdatasets::stranded_data
names(strand_pat) <- c(
  "stranded_class", "age", "care_home_ref_flag", "medically_safe_flag",
  "hcop_flag", "needs_mental_health_support_flag",
  "previous_care_in_last_12_month", "admit_date", "frail_descrip"
)
## Type conversions applied in the same order as the original pipeline:
## characters -> factors first, then the admission date is parsed, then
## the *_flag columns become logicals.
strand_pat <- strand_pat %>%
  mutate(across(where(is.character), as.factor)) %>%
  mutate(admit_date = as.Date(admit_date, format = "%d/%m/%Y")) %>%
  mutate(across(ends_with("flag"), as.logical))

# Explore data ----
## Analyse class imbalance
class_bal_table <- table(strand_pat$stranded_class)
prop_tab <- prop.table(class_bal_table)
upsample_ratio <- class_bal_table[2] / sum(class_bal_table)
# Partition into training and test data splits ----
split <- initial_split(strand_pat)
train_data <- training(split)
test_data <- testing(split)

# Create Recipe ----
## Define the preprocessing recipe applied to the dataset.
stranded_rec <-
  recipe(stranded_class ~ ., data = train_data) %>%
  # Make day-of-week and month features from the admission date,
  # then remove the raw date column.
  step_date(admit_date, features = c("dow", "month")) %>%
  step_rm(admit_date) %>%
  # Upsample the minority (positive) class to parity with the majority.
  # BUG FIX: themis's `over_ratio` is the target minority-to-majority
  # frequency RATIO, not a proportion of the whole dataset. The original
  # passed class_bal_table[2] / sum(class_bal_table) (the minority's share
  # of all rows), which is always below the existing minority/majority
  # ratio, so the step silently added no rows. over_ratio = 1 balances
  # the classes, which is what the comment intended.
  themis::step_upsample(stranded_class, over_ratio = 1) %>%
  # Drop zero-variance predictors and normalise numeric predictors.
  step_zv(all_predictors()) %>%
  step_normalize(all_numeric_predictors())

## Prepare the recipe on the training data and bake both splits
## (the baked test data is needed later for engine-level explanation).
stranded_recipe_prep <- prep(stranded_rec, training = train_data)
stranded_train_bake <- bake(stranded_recipe_prep, new_data = NULL)
stranded_test_bake <- bake(stranded_recipe_prep, new_data = test_data)
# Model specification ----
## Fixed hyperparameters (apart from epochs) follow the TabNet-S
## configuration from the TabNet paper; the widths, number of steps and
## learning rate are left to be tuned.
tabnet_model <-
  tabnet(
    epochs             = 5,
    batch_size         = 300,
    decision_width     = tune(),
    attention_width    = tune(),
    num_steps          = tune(),
    penalty            = 0.000001,
    virtual_batch_size = 256,
    momentum           = 0.6,
    feature_reusage    = 1.5,
    learn_rate         = tune()
  ) %>%
  set_mode("classification") %>%
  set_engine("torch", verbose = TRUE)

# Workflow connecting the recipe and the model ----
tabnet_workflow <-
  workflow() %>%
  add_recipe(stranded_rec) %>%
  add_model(tabnet_model)
# Specify parameter tuning grid ----
## FIX: tune::parameters() on a workflow is deprecated; the tidymodels
## replacement is extract_parameter_set_dials() (re-exported from hardhat).
grid <-
  tabnet_workflow %>%
  extract_parameter_set_dials() %>%
  update(
    decision_width = decision_width(range = c(20, 40)),
    attention_width = attention_width(range = c(20, 40)),
    num_steps = num_steps(range = c(4, 6)),
    # learn_rate() ranges are expressed on the log10 scale,
    # i.e. this spans ~0.003 to 0.1.
    learn_rate = learn_rate(range = c(-2.5, -1))
  ) %>%
  grid_max_entropy(size = 8)
# Parameter Tuning ----
## Build five cross-validation folds from the training data.
folds <- vfold_cv(train_data, v = 5)
set.seed(777)
## Tune with the racing (win/loss) method: poorly performing
## hyperparameter candidates are eliminated early across resamples.
res <-
  tune_race_win_loss(
    tabnet_workflow,
    resamples = folds,
    grid = grid,
    metrics = metric_set(roc_auc, accuracy),
    control = control_race()
  )
## View performance metrics across all hyperparameter permutations
res %>%
  collect_metrics()

## Select the best hyperparameter combination according to AUC
tabnet_best_model <- res %>%
  select_best(metric = "roc_auc")

# Finalise the Model: Select best model ----
## Update the workflow with the winning hyperparameters.
## CLEANUP: the original recomputed select_best() inline even though
## tabnet_best_model already held the result — reuse it instead.
final_tabnet_workflow <- tabnet_workflow %>%
  finalize_workflow(tabnet_best_model)

## Fit the finalised workflow to the training data
final_tabnet_model <- final_tabnet_workflow %>%
  fit(data = train_data)

## Pull the parsnip model fit from the workflow (printed, not stored)
final_tabnet_model %>%
  extract_fit_parsnip()

## Class-probability predictions from the final model on the training data
final_tabnet_model %>%
  predict(train_data, type = "prob")
# Fit the model to the test data ----
## last_fit() refits the finalised workflow on the full training set and
## evaluates it once on the held-out test set.
## FIX: pass the finalised (untrained) workflow rather than the
## already-fitted model object — last_fit() performs its own fit on the
## training split, so handing it a trained workflow is at best redundant.
tabnet_fit_final <- final_tabnet_workflow %>%
  last_fit(split)

## Metrics on the test set
tabnet_fit_final %>%
  collect_metrics()

## Predictions on the test set, joined back onto the raw test rows
tabnet_fit_final %>%
  collect_predictions() %>%
  dplyr::select(starts_with(".pred")) %>%
  bind_cols(test_data)

## Confusion matrix on the test set
tabnet_fit_final %>%
  collect_predictions() %>%
  conf_mat(stranded_class, .pred_class)

## Generate ROC curve; the first factor level ("Not Stranded") is treated
## as the event by default, so its probability column is supplied.
roc_plot <-
  tabnet_fit_final %>%
  collect_predictions() %>%
  roc_curve(stranded_class, ".pred_Not Stranded") %>%
  autoplot()
# Variable Importance - TabNet attention-based explanation ----
## Extract the fitted workflow from the last_fit() result.
tabnet_wf_model <- tabnet_fit_final$.workflow[[1]]

## FIX: the original stored the result in a variable named
## `tabnet_explain`, shadowing the tabnet::tabnet_explain() function it
## had just called — renamed to avoid the clash. tabnet_explain() works
## at the engine level, so it needs the baked (preprocessed) test data.
tabnet_explanation <- tabnet_explain(
  tabnet_wf_model %>% extract_fit_engine(),
  new_data = stranded_test_bake
)
autoplot(tabnet_explanation)
# Save Model and Metrics ----
## Save Model
### Measure object size of the fitted workflow
obj_size(tabnet_wf_model)
### Show which components of the workflow use the most memory
weigh(tabnet_wf_model)
### Butcher the workflow to strip parts not needed for prediction
tabnet_wf_model_reduced <- butcher::butcher(tabnet_wf_model)
### Check the size difference
print(obj_size(tabnet_wf_model))
print(obj_size(tabnet_wf_model_reduced))
obj_size(tabnet_wf_model) - obj_size(tabnet_wf_model_reduced)
### Save the reduced workflow as an RDS file.
### FIX: saveRDS() errors if the target directory does not exist —
### create it first (no-op when it is already there).
dir.create("./saved_models", showWarnings = FALSE, recursive = TRUE)
saveRDS(tabnet_wf_model_reduced, file = "./saved_models/tabnet_stranded.rds")
# Reading in workflow and predicting ----
## Simulates a fresh session: reload the saved workflow from disk.
tabnet_wf_model <- readRDS(file = "./saved_models/tabnet_stranded.rds")

## Predict on a random sample of the test data with the loaded workflow ----
test_sample <- slice_sample(test_data, n = 50)
predictions <- predict(tabnet_wf_model, test_sample)
## Put the observed class alongside the predicted class for inspection.
cbind(predictions, stranded_class = test_sample$stranded_class)