-
Notifications
You must be signed in to change notification settings - Fork 0
/
masked_pretraining.R
75 lines (60 loc) · 1.9 KB
/
masked_pretraining.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Install development version of tabnet package
remotes::install_github("mlverse/tabnet")
library(tidymodels)
library(parameters)
library(skimr)
library(remotes)
library(tidyverse)
library(parallel)
library(doParallel)
library(vip)
library(themis)
library(lme4)
library(BradleyTerry2)
library(finetune)
library(butcher)
library(lobstr)
library(lubridate)
library(NHSRdatasets)
library(torch)
library(tabnet)
library(yardstick)
# Seed both the R and the torch random number generators so that the data
# splits and the network training below are repeatable.
set.seed(777)
torch_manual_seed(777)

# Masked pretraining ----
# Workflow: first train without labels (unsupervised) on the bulk of the data,
# then fine-tune with labels (supervised) on a small labelled subset. In other
# words, we build our own pretrained model.
data("lending_club", package = "modeldata")

# 90% of the rows form the unlabelled pool, the remaining 10% stay labelled.
split <- initial_split(lending_club, prop = 9 / 10, strata = Class)

# Blank out the outcome so the pretraining data carries no class labels.
unsupervised <- mutate(training(split), Class = NA)
supervised <- testing(split)
# Preprocessing recipe ----
# Define the recipe on the unlabelled data, normalise every numeric column and
# estimate the normalisation statistics from that same data (prep).
prep_unsup <- prep(
  step_normalize(
    recipe(Class ~ ., data = unsupervised),
    all_numeric()
  )
)

# new_data = NULL returns the processed version of the data the recipe was
# prepped on, i.e. the normalised unsupervised set.
unsupervised_baked_df <- bake(prep_unsup, new_data = NULL)
# Unsupervised pretraining ----
# Train TabNet without an outcome: the (all-NA) Class column is dropped from
# the predictors and y is NULL. 20% of the rows are held out for validation.
pretrained_model <- tabnet_pretrain(
  x = select(unsupervised_baked_df, -Class),
  y = NULL,
  epochs = 25,
  valid_split = 0.2,
  verbose = TRUE
)
# Supervised fine-tuning ----
# Split the labelled subset, stratifying on the outcome so the class balance is
# kept in the training part. (The argument was previously misspelled `strate`,
# so no stratification was applied.) Only the training part of `split_s` is
# used in this script.
split_s <- initial_split(supervised, strata = Class)
train <- training(split_s)

# Reuse the recipe prepped on the unsupervised data so the labelled rows are
# normalised with the same statistics the pretrained model saw.
supervised_train_df <- prep_unsup %>%
  bake(new_data = train)

# Continue training from the pretrained network, now with class labels;
# 20% of the rows are held out for validation.
model_fit <- tabnet_fit(
  x = supervised_train_df %>% select(-Class),
  y = supervised_train_df$Class,
  tabnet_model = pretrained_model,
  valid_split = 0.2,
  epochs = 10,
  verbose = TRUE
)
# Explainability ----

# Unsupervised: compute and plot the explanation output of the pretrained
# model on the data it was pretrained on.
pretrain_explain <- pretrained_model %>%
  tabnet_explain(new_data = unsupervised_baked_df)
autoplot(pretrain_explain)

# Supervised: the same for the fine-tuned model on the labelled training rows.
model_explain <- model_fit %>%
  tabnet_explain(new_data = supervised_train_df)
autoplot(model_explain)