Skip to content

Commit

Permalink
Merge pull request #217 from Ziyu-Mu/table_overfit
Browse files Browse the repository at this point in the history
overfitting table
  • Loading branch information
chriskolb authored Aug 6, 2024
2 parents 6db9559 + a69deea commit d0dc551
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 71 deletions.
60 changes: 0 additions & 60 deletions slides/regularization/rsrc/make_overfitting_table.R

This file was deleted.

52 changes: 52 additions & 0 deletions slides/regularization/rsrc/table_overfitting.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# ------------------------------------------------------------------------------
# intro
# TABLE:
# train and test MSE table using Neural Network and CART (overfitting).
# DATA: mtcars
# ------------------------------------------------------------------------------

library(nnet)
library(xtable)
library(mlr3)
library(mlr3learners)
set.seed(123)

# DATA -------------------------------------------------------------------------

lgr::get_logger("mlr3")$set_threshold("info")

task = tsk("mtcars")

lrn1 = lrn("regr.nnet", size = 100, maxit = 20000, MaxNWts = 10000, decay = 0, abstol = 1e-7)
lrn1$encapsulate = c(train = "evaluate", predict = "evaluate")
lrn2 = lrn("regr.rpart", minsplit = 2, cp = 0)

my_learners = list(lrn1, lrn2)
for (x in my_learners){
x$predict_sets = c("train", "test")
}

bg = benchmark_grid(task, my_learners, rsmp("cv", folds = 10))

bmr = benchmark(bg)

m1 = msr("regr.mse", predict_sets = c("test"), id = "mse-test")
m2 = msr("regr.mse", predict_sets = c("train"), id = "mse-train")

a = bmr$aggregate(measures = list(m1, m2))
print(a)

# TABLE ------------------------------------------------------------------------

# Create a 2x2 comparison table with rounded results
res = as.data.frame(a)
res = res[, c("mse-train", "mse-test")]

rownames(res) = c("Neural Network", "CART")
colnames(res) = c("Train MSE", "Test MSE")

latex_tab = xtable(res)

print(latex_tab, file = "table_overfitting.tex", include.rownames = TRUE, include.colnames = TRUE, comment = FALSE)


11 changes: 11 additions & 0 deletions slides/regularization/rsrc/table_overfitting.tex
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
\begin{table}[ht]
\centering
\begin{tabular}{rrr}
\hline
& Train MSE & Test MSE \\
\hline
Neural Network & 1.47 & 345.84 \\
CART & 0.00 & 6.91 \\
\hline
\end{tabular}
\end{table}
12 changes: 1 addition & 11 deletions slides/regularization/slides-regu-intro.tex
Original file line number Diff line number Diff line change
Expand Up @@ -92,17 +92,7 @@

\lz \lz

\begin{table}[ht]
\centering
\begin{tabular}{rrr}
\hline
& Train MSE & Test MSE \\
\hline
Neural Network & 3.68 & 19.98 \\
CART & 0.00 & 10.21 \\
\hline
\end{tabular}
\end{table}
\input{rsrc/table_overfitting.tex}

\lz \lz

Expand Down

0 comments on commit d0dc551

Please sign in to comment.