Skip to content

Commit

Permalink
update regu
Browse files Browse the repository at this point in the history
Signed-off-by: Ziyu-Mu <[email protected]>
  • Loading branch information
Ziyu-Mu committed May 21, 2024
1 parent 7654dcb commit 47dcd0d
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 7 deletions.
Binary file modified slides/regularization/figure/l1_reg_hess_02.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed slides/regularization/figure/wd-l2-geom.png
Binary file not shown.
30 changes: 26 additions & 4 deletions slides/regularization/rsrc/make_l1_reg_hess_plots.R
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# ------------------------------------------------------------------------------

source("utils.R")
library(gridExtra)

prc <- prcomp(X , scale. = FALSE)
X_dc <- prc$x
Expand Down Expand Up @@ -73,7 +74,7 @@ plot_l1_theta2 <- init_plot_l1 +
linetype="dashed", alpha=0.8, size=1.1) +
# geom_polygon(data = data.frame(x = c(-Inf, Inf, Inf, -Inf),
# y = c(theta_hat[,2], theta_hat[,2], 0, 0)),
# aes(x,y), fill="white", alpha=0.5) +
# aes(x,y), fill="white", alpha=0.5) +
geom_point(data=as.data.frame(theta_hat_2), aes(x=theta_hat_2[1], y=theta_hat_2[2]), color="green", size=2) +
geom_segment(data=cbind(start=as.data.frame(theta_hat), end=as.data.frame(theta_hat_2)),
aes(x=start.V1, y=start.V2,
Expand All @@ -83,9 +84,30 @@ plot_l1_theta2 <- init_plot_l1 +
annotate("label", x=-3, y=2, label="frac(lambda, H[\"2,2\"])",
parse=TRUE, color='black', size=4, fill="yellow")

plot_l1_theta_lasso <- plot_l1_theta2 +
plot_l1_theta2_dash <- init_plot_l1 +
geom_hline(yintercept=lambda/hessian[2,2], colour="yellow",
linetype="dashed", alpha=0.8, size=1.1) +
# geom_polygon(data = data.frame(x = c(-Inf, Inf, Inf, -Inf),
# y = c(theta_hat[,2], theta_hat[,2], 0, 0)),
# aes(x,y), fill="white", alpha=0.5) +
geom_point(data=as.data.frame(theta_hat_1), aes(x=theta_hat_1[1], y=theta_hat_1[2]), color="green", size=2) +
geom_segment(data=cbind(start=as.data.frame(theta_hat), end=as.data.frame(theta_hat_1)),
aes(x=start.V1, y=start.V2,
xend=end.V1, yend=end.V2), colour="green",
size=1.1, linetype = 'dashed', arrow = arrow(ends="last", type="closed", length=unit(0.04, "npc")),
arrow.fill="green") +
geom_point(data=as.data.frame(theta_hat_2), aes(x=theta_hat_2[1], y=theta_hat_2[2]), color="green", size=2) +
geom_segment(data=cbind(start=as.data.frame(theta_hat), end=as.data.frame(theta_hat_2)),
aes(x=start.V1, y=start.V2,
xend=end.V1, yend=end.V2), colour="green",
size=1.1, linetype = 'dashed', arrow = arrow(ends="last", type="closed", length=unit(0.04, "npc")),
arrow.fill="green") +
annotate("label", x=-3, y=2, label="frac(lambda, H[\"2,2\"])",
parse=TRUE, color='black', size=4, fill="yellow")

plot_l1_theta_lasso <- plot_l1_theta2_dash +
geom_point(data=as.data.frame(theta_l1_reg), aes(x=theta_l1_reg[1], y=theta_l1_reg[2]), color="orange", size=2) +
geom_segment(data=cbind(start=as.data.frame(theta_hat_2), end=as.data.frame(theta_l1_reg)),
geom_segment(data=cbind(start=as.data.frame(theta_hat), end=as.data.frame(theta_l1_reg)),
aes(x=start.V1, y=start.V2,
xend=end.V1, yend=end.V2), colour="orange",
size=1.1, arrow = arrow(ends="last", type="closed", length=unit(0.04, "npc")),
Expand All @@ -96,4 +118,4 @@ plot_l1_theta_lasso <- plot_l1_theta2 +
p2 <- grid.arrange(plot_l1_theta2, plot_l1_theta_lasso, ncol=2)

ggsave("../figure/l1_reg_hess_01.png", plot = p1, height = 3.5, width = 5.5)
ggsave("../figure/l1_reg_hess_02.png", plot = p2, height = 3.5, width = 5.5)
ggsave("../figure/l1_reg_hess_02.png", plot = p2, height = 3.5, width = 5.5)
84 changes: 83 additions & 1 deletion slides/regularization/rsrc/make_l2_reg_hess_plots.R
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# ------------------------------------------------------------------------------

source("utils.R")
library(gridExtra)

lambda <- 50
beta_start <- c(0, 0)
Expand Down Expand Up @@ -37,6 +38,7 @@ theta_min_ridge_data <- as.data.frame(t(Q %*% theta_min_skew))
x1 <- seq(-1.5,2,length.out = 100)
x2 <- seq(-1,3.5,length.out = 100)


# R_emp
init_cond_plot <- plot_r_emp(R_emp, x1, x2) +
annotate("label", x = 0.75, y = 3, label = "hat(theta)",
Expand Down Expand Up @@ -64,6 +66,22 @@ rot_plot <- plot_r_emp(R_emp, x1, x2) +

rs <- sapply(1:2, function(i) S[i,i] / (S[i,i] + lambda))

theta_hat <- theta_proj1_data*rs[1] + theta_proj2_data*rs[2]
geom_l2_plot <- plot_r_emp(R_emp, x1, x2) +
theme(legend.position="none") + coord_fixed() +
geom_hline(yintercept = 0, colour="darkgrey", size=1.2) +
geom_vline(xintercept = 0, colour="darkgrey", size=1.2) +
geom_point(aes(x=beta_true[1], y=beta_true[2], color="red", size=3)) +
geom_point(aes(x=theta_hat[1], y=theta_hat[2], color="yellow", size=3))

geom_l2_plot <- geom_l2_plot +
annotate("label", x = 1.3, y = 1.5, label = "hat(theta)[Ridge]",
parse = TRUE, color = 'black', size = 3, fill = "yellow") +
annotate("label", x = 0.75, y = 3, label = "hat(theta)",
parse = TRUE, color = 'black', size = 3, fill = "red")

##############shang

scale_rot_plot <- rot_plot +
geom_segment(data=cbind(start=as.data.frame(t(c(0,0))), end=
theta_proj1_data*rs[1] ), size=0.9,
Expand Down Expand Up @@ -102,6 +120,70 @@ p2 <- grid.arrange(rot_plot, init_cond_plot, ncol=2)

p3 <- grid.arrange(scale_rot_plot, scale_plot, ncol=2)

### contour plot for l2
#set a wider range
x1 <- seq(-2,2,length.out = 100)
x2 <- seq(-1,5,length.out = 100)

#calculate ellipse distance
dis_elli <- function(x, y, theta){
dr1 <- x - beta_true[1]
dr2 <- y - beta_true[2]
data <- cbind(dr1, dr2)
mat <- matrix(c(cos(theta), sin(theta), -sin(theta), cos(theta)), nrow=2)
dr <- data %*% mat
dr[,1] <- dr[,1]/3 #axis ~= 3:1
apply(dr, 1, dist)
}

# Generate data points for plotting circles(ridge)
cir_list <- list()
seq_data <- seq(0, 2*pi, length.out=100) #points for one circle
i <- 1
for(mul in c(0.15, 0.6, 0.9, 1.26)){ #adjust radius
cir_list[[i]] <- data.frame(x=cos(seq_data)*mul, y=sin(seq_data)*mul)
i <- i + 1
}

eval_grid <- expand.grid(x1,x2)
eval_grid$r_emp <- apply(eval_grid, 1, R_emp)

#preserve only center part of contour lines
#chose the parameter manually acoording to the plots
distance <- dis_elli(eval_grid[,1], eval_grid[,2], theta=-pi/3-0.014)
eval_grid$dist <- distance
eval_grid_sub <- subset(eval_grid, dist < 1.5)

p_elli <- ggplot() +
geom_raster(data=eval_grid, aes(x=Var1, y=Var2, fill=r_emp)) +
geom_contour(data=eval_grid_sub, aes(x=Var1, y=Var2, z=r_emp),
colour="white", bins=7) +
theme(legend.position="none") + coord_fixed() +
xlab(expression(theta[1])) +
ylab(expression(theta[2])) +
#geom_point(aes(x=theta_hat[1], y=theta_hat[2], color="yellow", size=3)) +
scale_fill_viridis(end = 0.9)

p_ridge <- p_elli +
geom_path(data=cir_list[[1]], aes(x, y), color="white", linetype="dashed") +
geom_path(data=cir_list[[2]], aes(x, y), color="white", linetype="dashed") +
geom_path(data=cir_list[[3]], aes(x, y), color="white", linetype="dashed") +
geom_path(data=cir_list[[4]], aes(x, y), color="white", linetype="dashed")


p4 <- p_ridge +
geom_point(aes(x=beta_true[1], y=beta_true[2]), color="red", size=3) +
geom_point(aes(x=0.73, y=1.03), color="yellow", size=3) +#intersection point
annotate("label", x = 1.1, y = 0.9, label = "hat(theta)[Ridge]",
parse = TRUE, color = 'black', size = 3, fill = "yellow") +
annotate("label", x = 0.75, y = 3, label = "hat(theta)",
parse = TRUE, color = 'black', size = 3, fill = "red") +
geom_hline(yintercept = 0, colour="darkgrey", size=1.2) +
geom_vline(xintercept = 0, colour="darkgrey", size=1.2) +
xlim(-1.4, 1.6) +
ylim(-1, 4.5)

ggsave("../figure/l2_reg_hess_01_plot.png", plot = p1, width = 5.5, height = 3.5, dpi="retina")
ggsave("../figure/l2_reg_hess_02_plot.png", plot = p2, width = 5.5, height = 3.5, dpi="retina")
ggsave("../figure/l2_reg_hess_03_plot.png", plot = p3, width = 5.5, height = 3.5, dpi="retina")
ggsave("../figure/l2_reg_hess_03_plot.png", plot = p3, width = 5.5, height = 3.5, dpi="retina")
ggsave("../figure/l2_reg_hess_04_plot.png", plot = p4, width = 3, height = 5, dpi="retina")
4 changes: 2 additions & 2 deletions slides/regularization/slides-regu-geom-l2.tex
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,8 @@

\begin{figure}
\centering
\scalebox{0.9}{\includegraphics{figure/wd-l2-geom.png}}
\caption{\footnotesize The solid ellipses represent the contours of the unregularized objective and the dashed circles represent the contours of the $L2$ penalty. At $\hat{\thetab}_{\text{ridge}}$, the competing objectives reach an equilibrium.}
\scalebox{0.6}{\includegraphics{figure/l2_reg_hess_04_plot.png}}
\caption{\scriptsize The solid ellipses represent the contours of the unregularized objective and the dashed circles represent the contours of the $L2$ penalty. At $\hat{\thetab}_{\text{ridge}}$, the competing objectives reach an equilibrium.}
\end{figure}

\end{column}
Expand Down

0 comments on commit 47dcd0d

Please sign in to comment.