第 47 章 直線和曲線 curves and lines
兩種方式,一種是多項式回歸模型 (polynomial regression),另一種是 B-spline。
47.1 多項式回歸模型
使用二次,三次或者多次項構建預測變量的結構。
data("Howell1")
<- Howell1
d
plot( height ~ weight, d)
顯然我們有理由考慮使用二次方程的拋物線形式來做這個數據的模型:
\[ \mu_i = \alpha + \beta_1 x_i + \beta_2 x_i^2 \]
其完整的模型可以描述為:
\[ \begin{aligned} h_i & \sim \text{Normal}(\mu_u, \sigma) \\ \mu_i & = \alpha + \beta_1 x_i + \beta_2 x_i^2 \\ \alpha & \sim \text{Normal}(178, 20) \\ \beta_1 & \sim \text{Log-normal}(0, 1) \\ \beta_2 & \sim \text{Normal}(0, 1) \\ \sigma & \sim \text{Uniform}(0, 50) \end{aligned} \]
$weight_s <- (d$weight - mean(d$weight)) / sd(d$weight)
d$weight_s2 <- d$weight_s^2
d
.5 <- quap(
m4alist(
~ dnorm( mu, sigma),
height <- a + b1*weight_s + b2*weight_s2,
mu ~ dnorm( 178, 20 ),
a ~ dlnorm( 0, 1 ),
b1 ~ dnorm( 0, 1 ),
b2 ~ dunif(0 , 50)
sigma data = d
),
)
precis(m4.5)
## mean sd 5.5% 94.5%
## a 146.0574013 0.36897520 145.4677077 146.6470950
## b1 21.7330615 0.28888868 21.2713616 22.1947614
## b2 -7.8032547 0.27418364 -8.2414531 -7.3650563
## sigma 5.7744673 0.17646462 5.4924428 6.0564919
<- seq( from = -2.2, to = 2, length.out = 30)
weight.seq <- list(weight_s = weight.seq, weight_s2 = weight.seq^2)
pred_dat <- link( m4.5, data = pred_dat)
mu <- apply(mu, 2, mean)
mu.mean <- apply( mu, 2, PI, prob = 0.89)
mu.PI <- sim( m4.5, data = pred_dat)
sim.height <- apply( sim.height, 2, PI, prob = 0.89) height.PI
plot( height ~ weight_s, d, col = col.alpha(rangi2, 0.5))
lines( weight.seq, mu.mean )
shade( mu.PI, weight.seq)
shade(height.PI, weight.seq)
你可以更進一步,使用一個三次方(cubic)模型來看看是否接近實際數據本身的分佈。
\[ \begin{aligned} h_i & \sim \text{Normal}(\mu_u, \sigma) \\ \mu_i & = \alpha + \beta_1 x_i + \beta_2 x_i^2 + \beta_3x_i^3\\ \alpha & \sim \text{Normal}(178, 20) \\ \beta_1 & \sim \text{Log-normal}(0, 1) \\ \beta_2 & \sim \text{Normal}(0, 1) \\ \beta_3 & \sim \text{Normal}(0, 1) \\ \sigma & \sim \text{Uniform}(0, 50) \end{aligned} \]
$weight_s3 <- d$weight_s^3
d.6 <- quap(
m4alist(
~ dnorm( mu, sigma),
height <- a + b1*weight_s + b2*weight_s2 + b3*weight_s3,
mu ~ dnorm( 178, 20 ),
a ~ dlnorm( 0, 1 ),
b1 ~ dnorm( 0, 1 ),
b2 ~ dnorm( 0, 1 ),
b3 ~ dunif(0 , 50)
sigma data = d
),
)precis(m4.6)
## mean sd 5.5% 94.5%
## a 146.3945281 0.30998676 145.8991094 146.8899468
## b1 15.2197442 0.47626470 14.4585812 15.9809071
## b2 -6.2026158 0.25715787 -6.6136037 -5.7916278
## b3 3.5833594 0.22877313 3.2177358 3.9489831
## sigma 4.8298822 0.14694209 4.5950404 5.0647241
<- seq( from = -2.2, to = 2, length.out = 30)
weight.seq <- list(weight_s = weight.seq,
pred_dat weight_s2 = weight.seq^2,
weight_s3 = weight.seq^3)
<- link( m4.6, data = pred_dat)
mu <- apply(mu, 2, mean)
mu.mean <- apply( mu, 2, PI, prob = 0.89)
mu.PI <- sim( m4.6, data = pred_dat)
sim.height <- apply( sim.height, 2, PI, prob = 0.89)
height.PI
plot( height ~ weight_s, d, col = col.alpha(rangi2, 0.5))
lines( weight.seq, mu.mean )
shade( mu.PI, weight.seq)
shade(height.PI, weight.seq)
在使用標準化的 Z 值計算之後,如果你希望恢復到原來的尺度,而不是標準化的 x 軸,可以按照以下步驟:
plot( height ~ weight_s, d, col = col.alpha(rangi2, 0.50), xaxt = "n")
<- c(-2, -1, 0, 1, 2)
at <- at*sd(d$weight) + mean(d$weight)
labels axis( side = 1, at = at, labels = round(labels, 1))
47.2 平滑曲線 Splines
B-Splines 是基礎 (basic) 平滑曲線的涵義。下面的數據紀錄了每年春天櫻花開放的日期:
data("cherry_blossoms")
<- cherry_blossoms
d precis(d)
## mean sd 5.5% 94.5% histogram
## year 1408.0000000 350.88459641 867.77000 1948.23000 ▇▇▇▇▇▇▇▇▇▇▇▇▁
## doy 104.5405079 6.40703618 94.43000 115.00000 ▁▂▅▇▇▃▁▁
## temp 6.1418861 0.66364787 5.15000 7.29470 ▁▃▅▇▃▂▁▁
## temp_upper 7.1851512 0.99292057 5.89765 8.90235 ▁▂▅▇▇▅▂▂▁▁▁▁▁▁▁
## temp_lower 5.0989413 0.85034959 3.78765 6.37000 ▁▁▁▁▁▁▁▃▅▇▃▂▁▁▁
其中每年第一次確認的開花的日期被紀錄為 doy
變量。它的範圍是在每年的第 86 (三月底) ~124 (五月初)天:
range(d$doy, na.rm = TRUE)
## [1] 86 124
plot(d$year, d$doy)
為了評估這個每年第一天開花紀錄的日期是否隨著時間有怎樣的變化趨勢,我們選擇使用基礎平滑曲線的方法:
\[ \mu_i = \alpha + w_1 B_{i,1} + w_2 B_{i,2} + w_3 B_{i,3} + \dots \]
其中,
- \(B_{i,n}\) 是橫軸第 \(i\) 年份的第 \(n\) 個區間的基礎函數 (basis function)
- \(w_n\) 是該基礎函數本身的權重
假定我們選擇使用 15 個節點 (knots),也就是 \(n = 15\),來繪製該數據的平滑曲線:
<- d[complete.cases(d$doy), ] # complete cases on doy
d2 <- 15
num_knots <- quantile( d2$year, probs = seq(0, 1, length.out = num_knots))
knot_list
str(knot_list)
## Named num [1:15] 812 1036 1174 1269 1377 ...
## - attr(*, "names")= chr [1:15] "0%" "7.1428571%" "14.285714%" "21.428571%" ...
# 15 dates knot_list
## 0% 7.1428571% 14.285714% 21.428571% 28.571429% 35.714286% 42.857143% 50% 57.142857%
## 812 1036 1174 1269 1377 1454 1518 1583 1650
## 64.285714% 71.428571% 78.571429% 85.714286% 92.857143% 100%
## 1714 1774 1833 1893 1956 2015
我們來建立一個三次方平滑曲線,cubic spline:
<- bs(d2$year,
B knots = knot_list[ - c(1, num_knots)] ,
degree = 3, intercept = TRUE )
head(B)
## 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
## [1,] 1.00000000 0.000000000 0.00000000000 0.0000000e+00 0 0 0 0 0 0 0 0 0 0 0 0 0
## [2,] 0.96035713 0.039312303 0.00032983669 7.2860303e-07 0 0 0 0 0 0 0 0 0 0 0 0 0
## [3,] 0.76650948 0.220745951 0.01255948103 1.8509216e-04 0 0 0 0 0 0 0 0 0 0 0 0 0
## [4,] 0.56334070 0.385673722 0.04938483524 1.6007409e-03 0 0 0 0 0 0 0 0 0 0 0 0 0
## [5,] 0.54526700 0.398683678 0.05418946879 1.8598537e-03 0 0 0 0 0 0 0 0 0 0 0 0 0
## [6,] 0.45273210 0.459759692 0.08371386242 3.7943487e-03 0 0 0 0 0 0 0 0 0 0 0 0 0
B
其實是一個 827 行,17 列的矩陣。每一行是一個年份,對應 d2
數據框的每一行。每一列,則是對應一個基礎函數 basic function。為了繪製這些基礎函數,我們可以把他們和對應的年份作圖:
plot(NULL, xlim = range(d2$year), ylim = c(0, 1),
xlab = "year", ylab = "basis")
for( i in 1:ncol(B)) lines( d2$year, B[, i])
模型本身只是一個簡單線型回歸其實:
\[ \begin{aligned} D_i & \sim \text{Normal}(\mu_i, \sigma) \\ \mu_i & = \alpha + \sum_{k = 1}^K w_k B_{k,i} \\ \text{Priors :} & \\ \alpha & \sim \text{Normal}(100, 10) \\ w_j & \sim \text{Normal}(0, 10) \\ \sigma & \sim \text{Exponential} (1) \end{aligned} \] 對於像方差,標準差這樣必須大於零的參數來說,選用指數分佈作為先驗概率分佈其實是比較常見的。
實際運算這個模型,獲取 \(w_j\) 們的事後概率分佈:
.7 <- quap(
m4alist(
~ dnorm( mu, sigma ),
D <- a + B %*% w ,
mu ~ dnorm(100, 10),
a ~ dnorm(0, 10),
w ~ dexp(1)
sigma data = list(D = d2$doy, B = B),
), start = list( w = rep(0, ncol(B)))
)
precis(m4.7, depth = 2)
## mean sd 5.5% 94.5%
## w[1] -3.01911288 3.86119146 -9.190042571 3.15181682
## w[2] -0.82921519 3.87017151 -7.014496741 5.35606636
## w[3] -1.05526308 3.58495744 -6.784717475 4.67419131
## w[4] 4.84837372 2.87712842 0.250166821 9.44658062
## w[5] -0.83559661 2.87435170 -5.429365788 3.75817256
## w[6] 4.32740988 2.91486378 -0.331105417 8.98592517
## w[7] -5.31866630 2.80023011 -9.793974851 -0.84335776
## w[8] 7.85246858 2.80208766 3.374191300 12.33074587
## w[9] -1.00465423 2.88105863 -5.609142366 3.59983391
## w[10] 3.04274790 2.91011772 -1.608182285 7.69367808
## w[11] 4.66654932 2.89173401 0.044999865 9.28809877
## w[12] -0.14583740 2.86942487 -4.731732549 4.44005774
## w[13] 5.56171681 2.88744009 0.947029859 10.17640376
## w[14] 0.72146307 2.99932938 -4.072044567 5.51497070
## w[15] -0.80243201 3.29351897 -6.066111436 4.46124742
## w[16] -6.96037274 3.37577948 -12.355520340 -1.56522514
## w[17] -7.66646272 3.22276627 -12.817065657 -2.51585978
## a 103.34562113 2.36974606 99.558309230 107.13293304
## sigma 5.87659679 0.14375221 5.646852994 6.10634058
繪製增加了權重 \(w_j\) 和基礎函數結合的平滑曲線:
<- extract.samples( m4.7 )
post <- apply( post$w, 2, mean )
w
plot(NULL, xlim = range(d2$year), ylim = c(-6, 6),
xlab = "year", ylab = "basis * weight")
for( i in 1 : ncol(B)) lines( d2$year, w[i]*B[,i])
<- link( m4.7 )
mu dim(mu)
## [1] 1000 827
<- apply( mu, 2, PI, 0.97)
mu_PI dim(mu_PI)
## [1] 2 827
<- apply( mu, 2, mean)
mu_mean
plot( d2$year, d2$doy, col = col.alpha(rangi2, 0.3), pch = 16)
lines(d2$year, mu_mean)
shade(mu_PI, d2$year, col = col.alpha("black", 0.5))
圖形顯示的開花時間和年份之間的曲線關係,可見在1500年前後發生了某些情況。而且近年來似乎有提早開花的傾向。