forked from ThinkBigAnalytics/AoaDemoModels
-
Notifications
You must be signed in to change notification settings - Fork 0
/
training.R
66 lines (50 loc) · 2.08 KB
/
training.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
#' Attach every package this training script depends on.
#'
#' The attach order is significant for name masking and is kept as-is:
#' gbm, DBI, dplyr, tdplyr.
#'
#' @return The value of the final `library()` call (invisible).
LoadPackages <- function() {
  library(gbm)     # gbm(), gbm.perf()
  library(DBI)     # not called directly in this file; presumably needed by the connection layer (confirm)
  library(dplyr)   # tbl(), select(), mutate(), %>%
  library(tdplyr)  # td_create_context(), td_set_context()
}
# Attach the dependencies once, when this file is sourced, while hiding the
# packages' startup messages.
suppressPackageStartupMessages(LoadPackages())
#' Create a Vantage connection with tdplyr and register it as the context.
#'
#' Connection settings are read from the environment variables
#' `AOA_CONN_HOST`, `AOA_CONN_USERNAME` and `AOA_CONN_PASSWORD`.
#'
#' @return The connection object returned by `td_create_context()`.
#' @section Errors:
#' Stops with a message naming the variable(s) when any of the three
#' environment variables is unset or empty.
Connect2Vantage <- function() {
  required <- c("AOA_CONN_HOST", "AOA_CONN_USERNAME", "AOA_CONN_PASSWORD")

  # Sys.getenv() silently yields "" for an unset variable; ask for NA instead
  # so a missing setting is reported here rather than surfacing later as an
  # opaque connection failure.
  settings <- Sys.getenv(required, unset = NA_character_, names = TRUE)
  absent <- required[is.na(settings) | !nzchar(settings)]
  if (length(absent) > 0) {
    stop(
      "Missing required environment variable(s): ",
      paste(absent, collapse = ", "),
      call. = FALSE
    )
  }

  # Create Vantage connection using tdplyr
  con <- td_create_context(
    host = settings[["AOA_CONN_HOST"]],
    uid = settings[["AOA_CONN_USERNAME"]],
    pwd = settings[["AOA_CONN_PASSWORD"]],
    dType = "native"
  )

  # Set connection context
  td_set_context(con)
  con
}
#' Train a GBM classifier for `HasDiabetes` and save it as an RDS file.
#'
#' Reads the training table from Vantage, fits `gbm()` with a bernoulli
#' distribution using the configured hyperparameters, and writes the fitted
#' model to "artifacts/output/model.rds".
#'
#' @param data_conf List-like configuration; element `table` identifies the
#'   Vantage table to read.
#' @param model_conf List-like configuration; element `hyperParameters` must
#'   contain `shrinkage`, `cv.folds` and `n.trees`.
#' @param ... Accepted for interface compatibility; not used here.
#' @return The (invisible) result of `saveRDS()`; called for its side effect.
#' @section Errors:
#' Stops before connecting when `data_conf$table` or any required
#' hyperparameter is missing.
train <- function(data_conf, model_conf, ...) {
  print("Training model...")

  # Validate the configuration before opening a database connection, so a
  # bad config fails fast with a readable message. `[[` is used instead of
  # `$` because `$` partial-matches names on lists.
  table_name <- data_conf[["table"]]
  if (is.null(table_name)) {
    stop("data_conf is missing the 'table' entry", call. = FALSE)
  }

  # Load hyperparameters from model configuration
  hyperparams <- model_conf[["hyperParameters"]]
  required <- c("shrinkage", "cv.folds", "n.trees")
  absent <- setdiff(required, names(hyperparams))
  if (length(absent) > 0) {
    stop(
      "model_conf$hyperParameters is missing: ",
      paste(absent, collapse = ", "),
      call. = FALSE
    )
  }

  # Connect to Vantage and make sure the context is released again, even
  # when training fails part-way. A cleanup failure is downgraded to a
  # warning so it cannot mask the real outcome of the training run.
  con <- Connect2Vantage()
  on.exit(
    tryCatch(
      td_remove_context(),
      error = function(e) {
        warning(
          "Failed to remove Vantage context: ", conditionMessage(e),
          call. = FALSE
        )
      }
    ),
    add = TRUE
  )

  # Create tibble from table in Vantage
  table <- tbl(con, table_name)

  # Create dataframe from tibble, selecting the necessary columns and
  # mutating integer64 to integers. The casts are applied before
  # as.data.frame() collects the rows.
  data <- table %>%
    select(c(
      "NumTimesPrg", "PlGlcConc", "BloodP", "SkinThick", "TwoHourSerIns",
      "BMI", "DiPedFunc", "Age", "HasDiabetes"
    )) %>%
    mutate(
      NumTimesPrg = as.integer(NumTimesPrg),
      PlGlcConc = as.integer(PlGlcConc),
      BloodP = as.integer(BloodP),
      SkinThick = as.integer(SkinThick),
      TwoHourSerIns = as.integer(TwoHourSerIns),
      HasDiabetes = as.integer(HasDiabetes)
    ) %>%
    as.data.frame()

  # Train model
  model <- gbm(
    HasDiabetes ~ .,
    data = data,
    shrinkage = hyperparams[["shrinkage"]],
    distribution = "bernoulli",
    cv.folds = hyperparams[["cv.folds"]],
    n.trees = hyperparams[["n.trees"]],
    verbose = FALSE
  )
  print("Model Trained!")

  # Get optimal number of iterations
  best.iter <- gbm.perf(model, plot.it = FALSE, method = "cv")

  # clean the model (R stores the dataset on the model)
  model$data <- NULL

  # Keep the cross-validated optimum with the artifact instead of discarding
  # it, so whoever loads the model can pass it as `n.trees` when predicting.
  # The trees themselves are left untouched.
  model$best.iter <- best.iter

  # Save trained model; create the output directory first because saveRDS()
  # does not create missing parent directories.
  print("Saving trained model...")
  output_path <- "artifacts/output/model.rds"
  dir.create(dirname(output_path), recursive = TRUE, showWarnings = FALSE)
  saveRDS(model, output_path)
}