-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathREADME.Rmd
74 lines (58 loc) · 2.34 KB
/
README.Rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
---
output: github_document
---
```{r, echo = FALSE}
# Default chunk options for rendering this README to GitHub markdown:
#   fig.path - write generated figures under man/figures/
#   comment  - prefix printed output lines with "#>"
#   collapse - merge source and its output into a single block
knitr::opts_chunk$set(
  fig.path = "man/figures/",
  comment = "#>",
  collapse = TRUE
)
```
# Gaussian Process Regression
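A toy example of Gaussian process (GP) regression. A GP with a squared-exponential covariance is conditioned on five noise-free observations of sin(x). The figure below shows the observed data, the true mean, the predictive mean with pointwise 95% probability intervals, and 100 draws from the predictive distribution.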
```{r, echo = FALSE, message = FALSE}
# Packages: plgp supplies distance(); rmvnorm() comes from mvtnorm, so attach
# it explicitly rather than relying on another package to do so.
library(plgp)
library(mvtnorm)

# Fix the RNG so the sampled curves, and therefore the committed figure under
# man/figures/, are identical every time this README is knitted.
set.seed(1)

# Training data: n noise-free observations of sin(x) on [0, 3 * pi] ----
n <- 5
X <- matrix(seq(0, 3 * pi, length.out = n), ncol = 1)
y <- sin(X)

# Covariance among training inputs: exp(-D), where D holds the pairwise
# squared distances between x locations. The tiny jitter `eps` added to the
# diagonal keeps the matrix numerically invertible.
D <- distance(X)
eps <- sqrt(.Machine$double.eps)
Sigma <- exp(-D) + diag(eps, ncol(D))

# Testing design: 50 points on a grid slightly wider than the training range
XX <- matrix(seq(-0.5, 3 * pi + 0.8, length.out = 50), ncol = 1)
DXX <- distance(XX)                      # distances among testing locations
SXX <- exp(-DXX) + diag(eps, ncol(DXX))  # covariance among testing locations
DX <- distance(XX, X)                    # distances testing-to-training
SX <- exp(-DX)                           # cross-covariance testing-to-training

# Predictive (posterior) distribution at the testing locations ----
Si <- solve(Sigma)
mup <- SX %*% Si %*% y                   # predictive mean
Sigmap <- SXX - SX %*% Si %*% t(SX)      # predictive covariance matrix
YY <- rmvnorm(100, mup, Sigmap)          # 100 predictive draws, one per row

# Pointwise 95% quantile-based error bars. Round-off can leave a diagonal
# entry of Sigmap marginally below zero; clamp it so sqrt() never gives NaN.
sdp <- sqrt(pmax(diag(Sigmap), 0))
q1 <- mup + qnorm(0.025, 0, sdp)
q2 <- mup + qnorm(0.975, 0, sdp)

# Plot ----
# Every panel starts from an empty frame (type = "n") whose axis limits span
# the testing grid and the sampled curves, so all panels share the same axes.
# Layers are then added cumulatively from panel to panel. The previous par()
# settings are restored once the figure is complete.
op <- par(mfrow = c(2, 2))

matplot(XX, t(YY), type = "n", xlab = "x", ylab = "y",
        main = "Gaussian Process Regression",
        sub = "Observed data (black points)")
points(X, y, pch = 20, cex = 1.5)

matplot(XX, t(YY), type = "n", xlab = "x", ylab = "y",
        main = "True mean outcome (blue)",
        sub = "Unobserved mean; to be estimated given the observed data")
points(X, y, pch = 20, cex = 1.5)
lines(XX, sin(XX), col = "blue")

matplot(XX, t(YY), type = "n", xlab = "x", ylab = "y",
        main = "Estimated mean (solid black)",
        sub = "95% probability intervals (dashed black)")
points(X, y, pch = 20, cex = 1.5)
lines(XX, sin(XX), col = "blue")
lines(XX, mup, lwd = 2)
lines(XX, q1, lwd = 2, lty = 2)
lines(XX, q2, lwd = 2, lty = 2)

# The final panel actually draws the sampled curves, in gray.
matplot(XX, t(YY), type = "l", col = "gray", lty = 1, xlab = "x", ylab = "y",
        main = "Samples from the predictive distribution (grey)",
        sub = "95% of the grey lines fall within the dashed lines - as expected")
points(X, y, pch = 20, cex = 1.5)
lines(XX, sin(XX), col = "blue")
lines(XX, mup, lwd = 2)
lines(XX, q1, lwd = 2, lty = 2)
lines(XX, q2, lwd = 2, lty = 2)

par(op)
```