第 55 章 貝葉斯多層回歸模型 multilevel models

Many statistical models also have anterograde amnesia. As the models move from one cluster - individual, group, location - in the data to another, estimating parameters for each cluster, they forget everything about the previous cluster. … These models implicitly assume that nothing learned about any one category informs estimates for the other categories – the parameters are independent of one another and learn from completely separate proportions of the data. This would be like forgetting you had ever been in a cafe, each time you go to a new cafe. Cafes do differ, but they are also alike.

Richard McElreath

其實,當我們開始使用回歸模型時,最推薦的就是從多層回歸模型入手,把它當作一種應該實施的默認選項。當然的確非多層回歸的簡單模型在一些場合下就能夠勝任數據分析的過程給出滿意的結果,但事實上更多時候你會發現多層回歸模型會更加出色的幫助我們理解這個世界。所以最好的狀態其實是,我們先從多層回歸模型入手分析數據,隨着分析的深入,過程中我們可能發現不再需要多層模型結構就能完成分析任務。這其實好過我們從一開始就忽略掉了多層回歸模型的這一關鍵的可能性。

55.1 多層數據實例:蝌蚪和青蛙數據 multilevel tadpoles

data("reedfrogs")
d <- reedfrogs
str(d)
## 'data.frame':    48 obs. of  5 variables:
##  $ density : int  10 10 10 10 10 10 10 10 10 10 ...
##  $ pred    : Factor w/ 2 levels "no","pred": 1 1 1 1 1 1 1 1 2 2 ...
##  $ size    : Factor w/ 2 levels "big","small": 1 1 1 1 2 2 2 2 1 1 ...
##  $ surv    : int  9 10 7 10 9 9 10 9 4 9 ...
##  $ propsurv: num  0.9 1 0.7 1 0.9 0.9 1 0.9 0.4 0.9 ...

現在我們只關心上述數據中生存下來的蝌蚪數量 surv,和開始時的蝌蚪數量 density。該數據包涵了很多的變化和不確定性,或者叫做方差 variance。這些變化和不確定性可能來自不同的實驗條件,或者未知的原因。所以,假設每一行數據中的10只蝌蚪,被放在了不同的水池裏,也就是說,上面的數據中我們有48個水池做重複的實驗。於是該數據就可以被理解爲是重複相似的實驗,但是每次的實驗又有一些微妙的不同。每一個水池,就是一個數據的層級 ‘cluster’。如果我們忽略這個層級的概念,我們可能就忽略掉了他們本身在實驗開始之時的基線生存狀況 (baseline survival) 本身可能存在的不確定性 (variation)。這個不確定性,或者叫基線生存狀況的方差可能掩蓋住一些重要的發現。如果我們允許每個水池擁有自己單獨的其實狀態,也就是函數的截距,但是假如僅僅使用啞變量的方法 dummy variable,那其實我們就掉進了進行性健忘症的陷阱裏。因爲雖然他們是不同的水池做的實驗,但是一個水池的結果其實是能提示或者告訴我們其他水池的實驗結果的一些信息的,而不是完全地相互獨立毫無關聯性。

所以我們需要的其實是一個同時能夠允許每個水池的蝌蚪生存擁有自己的起始狀態,也就是函數的截距,且同時考慮到他們之間是有關聯性的,也就是這些截距之間是有一定的方差的。這樣的模型就被叫做隨機截距模型 (varing intercepts models),這樣的模型是最簡單的多層回歸模型。下面的模型用於預測每個不同的水池中實驗過後蝌蚪的生存狀況 (mortality) :

\[ \begin{aligned} S_i & \sim \text{Binomial}(N_i, p_i) \\ \text{logit}(p_i) & = \alpha_{\text{TANK}[i]} & [\text{unique log-odds for each tank}] \\ \alpha_j & = \text{Normal}(0, 1.5) & \text{for } j = 1, \dots, 48 \end{aligned} \]

這個模型很容易可以編碼成爲 Stan 模型:

# make the tank cluster variable
d$tank <- 1:nrow(d)

dat <- list(
  S = d$surv, 
  N = d$density, 
  tank = d$tank
)
# approximate posterior

m13.1 <- ulam(
  alist(
    S ~ dbinom( N, p ), 
    logit(p) <- a[tank], 
    a[tank] ~ dnorm(0, 1.5)
  ), data = dat, chains = 4, log_lik = TRUE
)

saveRDS(m13.1, "../Stanfits/m13_1.rds")
m13.1 <- readRDS("../Stanfits/m13_1.rds")
precis(m13.1, depth = 2)
##               mean         sd         5.5%        94.5%     n_eff      Rhat4
## a[1]   1.706135516 0.75126999  0.559612202  2.980799190 3565.4398 0.99914084
## a[2]   2.418378688 0.89508779  1.085855358  3.917982001 4098.7241 0.99870974
## a[3]   0.755285062 0.62854994 -0.229874740  1.808953868 5844.2221 0.99874415
## a[4]   2.398939738 0.87583172  1.064463796  3.892148654 3893.4841 0.99826780
## a[5]   1.711930366 0.78109476  0.538386909  3.059664189 4282.7151 1.00000591
## a[6]   1.712400845 0.76663422  0.556389436  2.986639219 3697.2250 0.99865881
## a[7]   2.433901051 0.96823924  0.983975047  4.068994913 3683.5712 0.99851080
## a[8]   1.728234678 0.77691481  0.579123478  3.070141912 2552.7945 1.00061031
## a[9]  -0.378787854 0.63017246 -1.411607834  0.601533811 4502.5692 0.99857748
## a[10]  1.704566280 0.74267187  0.563784136  2.922636900 4936.8884 0.99871181
## a[11]  0.759081120 0.63171344 -0.223449189  1.861229403 4293.1364 0.99840630
## a[12]  0.369263021 0.60895344 -0.575697499  1.317297190 5681.1496 0.99882110
## a[13]  0.770336180 0.64974231 -0.215174796  1.833888814 4366.6278 0.99923073
## a[14] -0.009729068 0.61875932 -0.994355905  1.010911262 5918.5699 0.99849235
## a[15]  1.703282129 0.74387855  0.573729989  2.946188044 4247.8544 0.99910230
## a[16]  1.743809558 0.80544535  0.559255295  3.079851444 4219.7553 0.99904420
## a[17]  2.549011923 0.68607770  1.554237307  3.689000199 4117.3536 0.99855216
## a[18]  2.150779207 0.61378539  1.244894345  3.194931728 4290.9635 0.99873434
## a[19]  1.818262637 0.55673559  0.992956540  2.740192008 6254.5908 0.99845271
## a[20]  3.099216840 0.79046813  1.922768458  4.458605238 3870.3141 0.99860113
## a[21]  2.141388712 0.59123864  1.278037726  3.133666741 3778.8639 0.99947441
## a[22]  2.143619964 0.63146412  1.192457415  3.182062745 3763.2670 0.99863539
## a[23]  2.137536255 0.62219243  1.174864938  3.181079132 4368.0224 0.99870273
## a[24]  1.536828664 0.52728459  0.761942359  2.375794268 4735.8741 0.99837462
## a[25] -1.105415076 0.43496285 -1.851966059 -0.435036245 3482.7299 0.99877398
## a[26]  0.076751268 0.40535395 -0.555377180  0.711271683 4550.4709 0.99868506
## a[27] -1.540868852 0.48401432 -2.359101123 -0.802931346 3882.5694 0.99842876
## a[28] -0.545748574 0.40591978 -1.193630669  0.101944128 4532.7825 0.99856686
## a[29]  0.075142168 0.38799070 -0.520229947  0.696214974 4880.8001 0.99878543
## a[30]  1.325013361 0.48797180  0.579469610  2.140616726 3983.1118 0.99869221
## a[31] -0.726599283 0.41815857 -1.414086716 -0.081095374 4260.4117 0.99865729
## a[32] -0.393374796 0.40388534 -1.026337857  0.233711720 4460.4798 1.00030775
## a[33]  2.851106794 0.66962348  1.863101375  3.976619545 4075.9763 0.99975478
## a[34]  2.439434481 0.53666179  1.636527087  3.330027234 3611.3375 0.99851844
## a[35]  2.462536698 0.58911895  1.585811555  3.493084708 4159.2509 0.99855632
## a[36]  1.894453091 0.47246927  1.166990383  2.700956176 3864.1837 0.99932504
## a[37]  1.897154130 0.47606314  1.193316198  2.690424743 4514.2042 0.99864540
## a[38]  3.371519718 0.80242385  2.209244462  4.757121346 3328.1092 0.99919367
## a[39]  2.466813623 0.56569064  1.614435107  3.428984467 4161.5867 0.99940885
## a[40]  2.164210136 0.52798452  1.379469542  3.097029995 2913.2789 0.99942866
## a[41] -1.901610019 0.47229892 -2.724517479 -1.207876441 4606.9855 0.99941201
## a[42] -0.631228672 0.33261100 -1.175636261 -0.103720359 4309.2499 0.99891339
## a[43] -0.508448320 0.34917306 -1.088393845  0.045662208 4497.9745 0.99878071
## a[44] -0.392124218 0.34118815 -0.948064494  0.143533849 4602.5032 0.99906926
## a[45]  0.507571894 0.34208208 -0.037475944  1.061806146 4501.6383 0.99816571
## a[46] -0.633508160 0.36474922 -1.228174857 -0.061322724 4888.1314 0.99845844
## a[47]  1.912347263 0.48151725  1.209880492  2.741920346 3341.1668 0.99929986
## a[48] -0.051961695 0.34246837 -0.612204797  0.498751869 6207.4457 0.99858292

你會看見模型的運算結果是告訴我們 48 個池塘本身的基線生存狀況,也就是有 48 個截距。但是 m13.1 並不是一個多層回歸模型,下面的模型中關鍵部分的加入才使得這個模型變得更加有意義:

\[ \begin{aligned} S_i & \sim \text{Binomial}(N_i, p_i) \\ \text{logit}(p_i) & = \alpha_{\text{TANK}[i]} \\ \alpha_j & \sim \text{Normal}(\color{blue}{\bar{\alpha}, \sigma}) & [\text{adaptive prior}] \\ \color{blue}{\bar{\alpha}} & \color{blue}{\sim \text{Normal}(0, 1.5)} & [\text{prior for average tank}] \\ \color{blue}{\sigma} & \color{blue}{\sim \text{Exponential}(1)} & [\text{prior for standard deviation of tanks}] \end{aligned} \]

上述模型中值得注意的是,除了允許不同水池的基線生存狀況,也就是截距可以各不相同,我們還允許這些截距之間存在聯繫。也就是這些截距本身是服從正(常)態分佈的,該正(常)態分佈的均值是 \(\bar{\alpha}\),標準差是 \(\sigma\)。這個截距服從的正(常)態分佈的參數,也有自己的先驗概率分佈。我們把這樣的參數叫做超參數 hyperparameters,他們是參數的參數,他們的先驗概率分佈被叫做超先驗 hyperpriors。我們可以用下面的代碼來運行這個模型:

m13.2 <- ulam(
  alist(S ~ dbinom( N, p ), 
  logit(p) <- a[tank], 
  a[tank] ~ dnorm(a_bar, sigma), 
  a_bar ~ dnorm( 0, 1.5 ), 
  sigma ~ dexp( 1 )
  ), data = dat, chains = 4, log_lik = TRUE
)
saveRDS(m13.2, "../Stanfits/m13_2.rds")

先比較一下這兩個模型之間的模型信息差別:

m13.2 <- readRDS("../Stanfits/m13_2.rds")
compare( m13.1, m13.2 )
##            WAIC        SE     dWAIC       dSE     pWAIC       weight
## m13.2 198.87517 7.3372677  0.000000        NA 20.403646 0.9997639405
## m13.1 215.57755 4.5643729 16.702381 4.1259847 26.092822 0.0002360595

從兩個模型之間的比較結果來看,首先,m13.2 只有 21 個有效的參數,比起實際的參數個數 50 個少了很多。這是因爲對這些截距增加了超參數的限制之後,他們受到了更多的約束,更加趨近於彼此。我們可以看看這個模型給出的截距分佈的超參數的事後概率分佈估計:

precis(m13.2, depth = 2, pars = c("a_bar", "sigma"))
##            mean         sd      5.5%     94.5%     n_eff      Rhat4
## a_bar 1.3393115 0.26442441 0.9230349 1.7758856 3022.9740 0.99984557
## sigma 1.6125652 0.20782693 1.3171320 1.9735443 1846.8192 1.00098472

這裡的截距分佈的超參數的事後估計其實給出了十分精確的估計,其均值在 1.34 左右,標準差是1.62,這說明了不同的水池之間的關係十分近似。也就是說,我們使用這個多層回歸模型,讓模型自己從數據中去學習並獲得截距和截距之間的關係。這比起一開始我們自己給 m13.1 設定的標準差 1.5 還要激進。於是這個多層回歸模型事實上給模型參數的估計增加了更多的限制。

為了加深我們對這個激進的超參數的理解,我們把這兩個模型 m13.1, m13.2 給出的估計結果繪製成圖形來觀察:

# extract Stan Samples
post <- extract.samples(m13.2)
# post <- extract.samples(m13.1)


# compute mean intercept for each tank
d$propsurv.est <- logistic( apply( post$a, 2, mean ))

#  display raw proportions surviving in each tank
plot( d$propsurv, ylim = c(0, 1), pch = 16, xaxt = 'n', 
      xlab = 'tank', ylab = 'proportion survival', 
      col = rangi2, bty = "n")
axis(1, at = c(1, 16, 32, 48), labels = c(1, 16, 32, 48))


# overlay posterior means
points( d$propsurv.est )

# mark posterior mean probability across tanks 
abline( h = mean(inv_logit(post$a_bar)), lty = 2)

# draw vertical dividers between tank densities
abline( v = 16.5, lwd = 0.5 )
abline( v = 32.5, lwd = 0.5 )
text( 8, 0, 'small tanks')
text( 16 + 8, 0, 'medium tanks')
text( 32 + 8, 0, 'large tanks')
Empirical proportions of surviviors in each tadpolle tank, shown by the filled blue points, plotted with the 48 per-tank parameters from the multilevel model, shown by the black circles. The dashed line locates the average proportion of survivors across all tanks. The vertical lines divide tanks with different initial densities of tadpoles: small tanks (10 tadpoles), medium tanks (25), and large tanks (35). In every tank, the posterior mean from the multilevel model is closer to the dashed line than the empirical proportion is. This reflects the pooling of information across tanks, to help with inference.

圖 55.1: Empirical proportions of surviviors in each tadpolle tank, shown by the filled blue points, plotted with the 48 per-tank parameters from the multilevel model, shown by the black circles. The dashed line locates the average proportion of survivors across all tanks. The vertical lines divide tanks with different initial densities of tadpoles: small tanks (10 tadpoles), medium tanks (25), and large tanks (35). In every tank, the posterior mean from the multilevel model is closer to the dashed line than the empirical proportion is. This reflects the pooling of information across tanks, to help with inference.

55.1 中,橫軸是水池的編號,從左往右依次是從 1 到 48 號水池;縱軸是水池中蝌蚪生存下來的比例。圖中藍色的點是原始數據點,也就是實際觀察值 propsurv。黑色的鏤空點則是模型估計的每個水池的截距。水平的橫虛線是估計的所有水池的蝌蚪存活概率的平均值 \(\alpha\)。而圖中的縱向的實線是把水池按照實驗開始時的蝌蚪密度計算的不同類型的池子,從小,中,到大三種類型的池子,各16個。不難注意到我們能看見多層回歸模型給出的推測值都相對觀察值更靠近總體平均生存概率。看起來似乎是黑色鏤空的圓點都更加靠近數據分佈的中心,平均值附近。這種現象又被叫做縮水現象 shrinkage,這是由於增加了超參數之後的多層回歸模型的參數估計受到的限制性的調整 regularization。其次,我們也發現在圖左側,也就是起始蝌蚪密度較小的水池裏,多層回歸模型估計的生存概率值更加靠近總體平均值,也就是縮水得更加明顯,距離觀察數據比密度大的水池要遠。也就是說,在小的水池裏,我們更加容易發現模型估計值和觀察值之間的差別,但是在其實密度大的水池中,觀察值和模型估計值更加接近。最後,如果藍色的點離虛線的總體均值越遠,它和黑色點,多層回歸模型估計值之間的差別越大。

上述三種現象其實是在告訴我們一件很重要的事,也就是把信息綜合起來的話,每一個層級的參數估計都會受益得到提升和改善 (pooling information across clusters to improve estimates) 。這裏的綜合信息 pooling information 的意義是,每一個水池的數據,每一個層級的數據,都含有能提高和改善其他層級參數估計信息 each tank provides information that can be used to improve the estimates for all of the other tanks。這是因爲我們假設了每個水池的截距 log-odds 雖有變化但不獨立而是服從某個正(常)態分佈。有了這個分佈的假設,貝葉斯估計就能幫助我們共享信息給不同的數據層級。

那麼,模型估計的這些青蛙的總體生存概率的分佈是怎樣的呢?我們可以從它對應的模型事後概率分佈中獲得結果繪製成圖。我們先繪製事後概率分佈中前100個 \(\alpha, \sigma\) 組合的平均存活率的分佈:

# show first 100 populations in the posterior 
plot( NULL, xlim = c(-3, 4), ylim = c(0, 0.35), 
      xlab = "log-odds survive", 
      ylab = "Density", bty = "n")

for( i in 1:100 ) 
  curve( dnorm(x, post$a_bar[i], post$sigma[i]), add = TRUE, 
         col = col.alpha("black", 0.2))
The inferred population of survival across tanks. 100 Gaussian distributions of the log-odds of survival, sampled from the posteriro of m13.2.

圖 55.2: The inferred population of survival across tanks. 100 Gaussian distributions of the log-odds of survival, sampled from the posteriro of m13.2.

55.2 告訴我們,均值 \(\alpha\) 和它對應的標準差 \(\sigma\) 都是有相當程度不確定性的 (uncertainty)。

# sample 8000 imaginary tanks from the posterior distribution 
sim_tanks <- rnorm(80000, post$a_bar, post$sigma)

# transform to probability and visualize
dens( inv_logit(sim_tanks), lwd = 2, adj = 0.1, 
      xlab = "probability survive", bty = "n")
Survival probabilities for 8000 new simulated tanks, averaging over the posterior distribution in previous figure.

圖 55.3: Survival probabilities for 8000 new simulated tanks, averaging over the posterior distribution in previous figure.

55.2 多層回歸的變化的效應和過度擬合/過低擬合之間的交易 varing effects and the underfitting/overfitting trade-off

使用多層回歸模型使得模型可以估計不同(截距或者斜率)的效應,其最大的好處是能夠給出更加準確的估計。其原理就是使用混合效應模型其實使得模型儘量避免了過度/過低擬合。如果建立模型是爲了預測未知池塘中蝌蚪存活概率的話,我們可以有三種策略:

  1. 完全合併策略,complete pooling。這方法其實是把總體的水池蝌蚪生存率這一數據看作是不變的 the population survival probability of ponds is invariant,也就是固定一個截距,適用所有的池塘。
  2. 完全不合併策略,no pooling。這方法類似 m13.1 模型的方案,無視水池和水池之間可能存在的相關性,把每個水池都看作的獨立互不影響的。也就是進行性失憶 模型。
  3. 部分合併策略,partial pooling。這方法其實是 m13.2 模型的多層回歸模型方案,通過允許水池之間有相關性,使模型自行學習應有的超參數。

很顯然,第一個方案其實很不切合實際,雖然把所有的數據都彙總到一個點上,但是指望用這唯一的一個估計結果來適用所有的水池的生存概率,認爲所有的水池都會產生相同的結果是不符合現實情況的,這一方案認爲水池和水池之間生存概率不會有差別,沒有變化和靈活性, all ponds are identical。第二個方案則是另一個極端,認爲每一個水池都給出完全不同的結果,即使相同也是偶然,水池之間毫無關聯。圖 55.1 中的藍色實心點就是這樣的模型。用於估計每個點的位置的數據在第二個方案下都會變得很少,所以每一個估計都變得更加不精確。第三個方案就是多層回歸模型增加的混合效應,它的部分合併策略其實是第一個和第二個方案的折衷辦法,使得模型的估計給出更加靈活的結果,也更適合擴展到預測未知數據,同時避免了過度/過低擬合。

爲了展示這個效果,我們可以用計算機模擬一些生存率的蝌蚪水池數據作爲已知的結果,用不同的模型來分析獲得其估計,從而直觀地理解這三種方案的不同思路。

55.2.1 用於產生模擬數據的模型 the model

第一步是設計我們希望產生數據的模型。我們可以直接使用 m13.2 的模型的主體部分:

\[ \begin{aligned} S_i & \sim \text{Binomial}(N_i, p_i) \\ \text{logit}(p_i) & = \alpha_{\text{POND}[i]} \\ \alpha_j & \sim \text{Normal}(\bar{\alpha}, \sigma) \\ \bar{\alpha} & \sim \text{Normal}(0, 1.5) \\ \sigma & \sim \text{Exponential}(1) \end{aligned} \]

為了能順利從這個模型中產生模擬數據,我們需要給模型中賦予真實值的參數有:

  1. \(\bar{\alpha}\) 是總體池塘蝌蚪存活概率的平均值的對數比值 average log-odds of survival in the entire population of ponds。
  2. \(\sigma\) 是總體池塘蝌蚪存活概率平均值的對數比值的標準差 the standard deviation of the distribution of log-odds of survival among ponds。
  3. \(\alpha\) 一系列水池的蝌蚪存過概率的對數比值的真實值,作為模型的變動截距 a vector of individual pond intercepts, one for each pond。

此外,我們還需要設定每個水池的蝌蚪起始樣本量 \(N_i\),這些都設定完畢之後,就只剩下每個水池可能存活的蝌蚪數量了,這個可以使用二項分佈的隨機值來設定。

a_bar <- 1.5
sigma <- 1.5
nponds <- 60
Ni <- as.integer( rep(c(5, 10, 25, 35), each = 15 ))

我們設定了60個水池,起始蝌蚪的數量分別是 5,10,25,35 的池子各有15個。另外 \(\bar{\alpha}, \sigma\) 定義了我們設計下的總體存活率的對數比值 log-odds 的分佈特徵。接下來就是讓計算機生成符合這個分佈條件 \(\text{Normal}(\bar{\alpha}, \sigma)\) 的 60 個水池的存活率的對數比值作為各自的截距。

set.seed(5005) 
a_pond <- rnorm( nponds, mean = a_bar, sd = sigma)

上面的代碼生成了60個符合設定的均值和標準差的數據,作為每個水池的蝌蚪存活率的對數比值。最後,把這些數據合併成為一個數據框:

dsim <- data.frame( pond = 1:nponds, Ni = Ni, true_a = a_pond)
str(dsim)
## 'data.frame':    60 obs. of  3 variables:
##  $ pond  : int  1 2 3 4 5 6 7 8 9 10 ...
##  $ Ni    : int  5 5 5 5 5 5 5 5 5 5 ...
##  $ true_a: num  0.567 1.99 -0.138 1.857 3.912 ...

生成的數據框 dsim 有三個變量,一個是水池編號,一個是水池起始蝌蚪數量,一個是真實的存活概率的對數比值 (log-odds)。

55.2.2 模擬存活概率結果 simulate survivors

根據我們設定的每個水池的“真實”存活概率的對數比值,我們不難計算每個水池的“真實”存活概率:

\[ p_i = \frac{\exp(\alpha_i)}{1 +\exp(\alpha_i)} \]

使用 logistic 函數可以方便的計算並且讓計算機模擬一系列該水池的蝌蚪存活數量:

dsim$Si <- rbinom( nponds, prob = logistic(dsim$true_a), size = dsim$Ni)
str(dsim)
## 'data.frame':    60 obs. of  4 variables:
##  $ pond  : int  1 2 3 4 5 6 7 8 9 10 ...
##  $ Ni    : int  5 5 5 5 5 5 5 5 5 5 ...
##  $ true_a: num  0.567 1.99 -0.138 1.857 3.912 ...
##  $ Si    : int  4 5 1 5 5 5 2 5 4 4 ...

55.2.3 計算完全不合併策略 no-pooling estimates

在這個模型設定下,最簡單快速的是計算完全不合併策略時的結果。這可以直接從前面計算機生成的實驗數據計算獲得。先計算每個水池中蝌蚪的存活概率:

dsim$p_nopool <- dsim$Si / dsim$Ni
head(dsim)
##   pond Ni      true_a Si p_nopool
## 1    1  5  0.56673123  4      0.8
## 2    2  5  1.99002317  5      1.0
## 3    3  5 -0.13775688  1      0.2
## 4    4  5  1.85676651  5      1.0
## 5    5  5  3.91208800  5      1.0
## 6    6  5  1.95414869  5      1.0

數據 dsim 的新增一列 p_nopool 就是每個水池實際觀察到的蝌蚪存活概率。這個計算結果等同於我們把每個水池當作一個啞變量互無關聯時給出的模型估計結果。

55.2.4 計算部分合併策略的結果 partial-pooling estimates

我們來使用 Stan 運行這個部分合併結果的模型

dat <- list(
  Si = dsim$Si, 
  Ni = dsim$Ni, 
  pond = dsim$pond
)
m13.30 <- ulam(
  alist(
    Si ~ dbinom( Ni, p ), 
    logit(p) <- a_pond[pond], 
    a_pond[pond] ~ dnorm( a_bar, sigma ), 
    a_bar ~ dnorm( 0, 1.5 ), 
    sigma ~ dexp( 1 )
  ), data = dat, chains = 4
)
saveRDS(m13.30, "../Stanfits/m13_30.rds")

上面的模型運行計算的就是最基礎版本的隨機截距模型。我們來看一下它給出的 \(\bar{\alpha}, \sigma\) 的事後分佈情況,下面的結果包含了六十個水池的截距,會很長:

m13.30 <- readRDS("../Stanfits/m13_30.rds")
precis( m13.30, depth = 2 )
##                    mean         sd          5.5%       94.5%     n_eff      Rhat4
## a_pond[1]   1.650407110 0.96904063  0.1840157397  3.21610041 3602.5566 0.99872074
## a_pond[2]   2.834338499 1.24148870  1.0007366114  4.92653420 2882.0704 0.99911023
## a_pond[3]  -0.643249556 0.91047753 -2.1226558797  0.73631508 2442.9210 1.00118607
## a_pond[4]   2.879393705 1.23392475  1.0753994158  5.01899499 2140.7584 0.99862131
## a_pond[5]   2.850206711 1.27463172  0.9832639699  4.99317825 1818.6398 1.00009094
## a_pond[6]   2.875036377 1.27603932  1.0240878520  5.13249706 2580.0892 0.99906937
## a_pond[7]   0.073569339 0.80574369 -1.2100708197  1.33815259 3043.3444 0.99860653
## a_pond[8]   2.845285499 1.24419299  1.0925164731  5.04294899 3061.6630 0.99967991
## a_pond[9]   1.640764609 1.04090309  0.0658312282  3.40013596 3786.2020 0.99898298
## a_pond[10]  1.662719698 1.00690244  0.0942168850  3.33404818 2884.8516 0.99881739
## a_pond[11]  2.854369003 1.27008868  0.9986433001  4.88275188 2595.0291 0.99981048
## a_pond[12]  0.078809355 0.84263594 -1.2662538706  1.46245910 3379.5942 0.99913852
## a_pond[13]  2.841399020 1.23604483  1.0306836717  4.98853447 2082.3706 1.00009111
## a_pond[14]  2.885347058 1.24766958  1.1012827128  4.95147914 2010.1180 1.00160532
## a_pond[15]  2.838645518 1.22543550  1.0330515180  4.94891601 2733.4150 1.00068080
## a_pond[16]  1.575398583 0.75674601  0.4307555558  2.80348477 3558.0680 0.99881879
## a_pond[17] -1.441357255 0.72675686 -2.6746649858 -0.36628831 2295.6003 0.99957747
## a_pond[18]  1.069732355 0.65145174  0.0336161365  2.13106736 3718.5782 0.99905032
## a_pond[19] -0.940599103 0.67353423 -2.0642510405  0.12758493 4180.4317 0.99851252
## a_pond[20]  1.571571268 0.78323302  0.3953309826  2.85302461 3979.0358 0.99827203
## a_pond[21] -0.142395459 0.62805421 -1.1595189249  0.85277506 4230.9275 0.99920164
## a_pond[22]  2.259831678 0.93055697  0.9101710635  3.82472671 2625.5858 1.00019856
## a_pond[23]  3.241016009 1.09565777  1.6854998943  5.17061513 2012.3850 0.99952697
## a_pond[24]  0.606054685 0.63343113 -0.3899463101  1.61247305 5142.2846 0.99837767
## a_pond[25]  3.279252138 1.15974972  1.6087975632  5.29849803 2262.1498 0.99906149
## a_pond[26]  2.218857703 0.90489171  0.8620718729  3.79051487 2577.0982 1.00011192
## a_pond[27]  1.032862125 0.71037078 -0.0514290807  2.22168735 2789.1514 0.99941969
## a_pond[28]  2.247840392 0.89076760  0.9912766740  3.77271073 2158.9564 1.00095656
## a_pond[29]  1.566335229 0.76329714  0.4490445116  2.88024205 2990.4098 0.99957137
## a_pond[30]  1.043066174 0.71390861 -0.0550604298  2.24023090 4217.3874 1.00012562
## a_pond[31]  2.460298686 0.67071150  1.4600395439  3.59187574 2605.1400 0.99919053
## a_pond[32]  2.056988630 0.58082689  1.1738653532  3.03319837 3417.6131 0.99846202
## a_pond[33]  1.719997357 0.52738590  0.9109309041  2.60832973 5587.4489 0.99851302
## a_pond[34]  1.256999633 0.49178056  0.5168072533  2.03843119 3268.9943 0.99924019
## a_pond[35]  0.674469822 0.42408930  0.0064163076  1.36286305 3414.4603 0.99919911
## a_pond[36]  3.835241545 1.11351640  2.3118751443  5.80523282 1615.1664 1.00131024
## a_pond[37] -1.001666846 0.45192860 -1.7457914860 -0.32383094 4309.0830 0.99947060
## a_pond[38] -1.181771717 0.45084523 -1.9246579225 -0.47367188 3350.4592 0.99890768
## a_pond[39]  0.677536497 0.44100771 -0.0185801594  1.39114768 3663.6990 0.99867024
## a_pond[40]  3.860786217 1.12076950  2.2692482127  5.79948052 1667.3130 0.99854679
## a_pond[41]  3.821365313 1.06524390  2.3305627294  5.67578489 1612.7689 1.00105065
## a_pond[42]  2.462269046 0.69314717  1.4258509184  3.61833311 3186.8495 0.99857627
## a_pond[43] -0.128696920 0.40218054 -0.7690797367  0.50362840 3589.9168 0.99902865
## a_pond[44]  0.657109234 0.40292342  0.0236895055  1.29586999 3980.6607 0.99949000
## a_pond[45] -1.186495921 0.46040340 -1.9314086555 -0.48683175 4384.5500 0.99898465
## a_pond[46]  0.010147737 0.33077455 -0.5051117179  0.52945609 3516.9707 0.99883160
## a_pond[47]  4.065550464 1.03658464  2.6073678621  5.79930831 1435.8616 1.00103136
## a_pond[48]  2.079263555 0.50984644  1.3320833116  2.96382734 3284.9157 1.00002749
## a_pond[49]  1.857357020 0.48234328  1.1198986835  2.64213548 3374.6112 0.99971553
## a_pond[50]  2.787020417 0.68170929  1.7895947754  3.96695002 2601.4123 0.99863578
## a_pond[51]  2.406458965 0.58533949  1.5477499282  3.35287318 3604.8986 0.99978555
## a_pond[52]  0.357155211 0.33339316 -0.1620370866  0.88372445 4181.1260 0.99934405
## a_pond[53]  2.106903052 0.50035530  1.3361915299  2.93145339 3158.0118 0.99925148
## a_pond[54]  4.060400285 1.01624187  2.6345088811  5.87636234 1617.3083 1.00044270
## a_pond[55]  1.130150172 0.38675503  0.5261511543  1.78230406 2975.5885 0.99861533
## a_pond[56]  2.779640587 0.66678682  1.8432349230  3.83522304 2634.3398 0.99956022
## a_pond[57]  0.713532057 0.35419243  0.1587824074  1.28652167 3868.9928 0.99843233
## a_pond[58]  4.044509994 1.01564448  2.6091145979  5.84183473 2089.2851 0.99919628
## a_pond[59]  1.648179868 0.44517530  0.9570171345  2.39649263 2929.4914 0.99939655
## a_pond[60]  2.411574826 0.61119724  1.4920258316  3.43769741 2202.5257 0.99918695
## a_bar       1.659983578 0.24579290  1.2777780128  2.05546494 1231.0486 1.00115140
## sigma       1.671424362 0.23676423  1.3222544730  2.08017324  641.1696 1.00665994

很好,接下來就可以運算每個水池的模型預測存活概率,並且添加到我們預先設定好的實驗數據中去。爲了便於比較,先要計算真實的水池中蝌蚪的存活概率。最後一步,就是計算模型預測的存活率,和實際真實存活率之間的差距了,也叫模型估計誤差。然後把這兩個條件下的估計誤差進行繪製在同一張圖上直觀地比較:

m13.30 <- readRDS("../Stanfits/m13_30.rds")
post <- extract.samples( m13.30 )
dsim$p_partpool <- apply( inv_logit(post$a_pond), 2, mean)

dsim <- dsim %>% 
  mutate(p_true = inv_logit(true_a),
         nopool_error = abs(p_nopool - p_true), 
         partpool_erro = abs(p_partpool - p_true))

# or similarly you can use basic R command
# nopool_erro <- abs( dsim$p_nopool - dsim$p_true )
# partpool_erro <- abs( dsim$p_partpool - dsim$p_true )


plot( 1:60, dsim$nopool_error,
      xlab = "Pond", 
      ylab = "absolute error", bty = "n",
      col = rangi2, pch = 16, 
      ylim = c(0, 0.6))
points( 1:60, dsim$partpool_erro )


# mark posterior mean probability across tanks 
error_avg <- dsim %>%
  group_by(Ni) %>% 
  summarise(nopool_avg = mean(nopool_error), 
            partpool_avg = mean(partpool_erro))
segments(1, error_avg$nopool_avg[1], 
         16, error_avg$nopool_avg[1], 
         col = rangi2, lwd = 2)
segments(1, error_avg$partpool_avg[1], 
         16, error_avg$partpool_avg[1], 
         col = "black", lwd = 2, lty = 2)
segments(17, error_avg$nopool_avg[2], 
         32, error_avg$nopool_avg[2], 
         col = rangi2, lwd = 2)
segments(17, error_avg$partpool_avg[2], 
         32, error_avg$partpool_avg[2], 
         col = "black", lwd = 2, lty = 2)
segments(33, error_avg$nopool_avg[3], 
         46, error_avg$nopool_avg[3], 
         col = rangi2, lwd = 2)
segments(33, error_avg$partpool_avg[3], 
         46, error_avg$partpool_avg[3], 
         col = "black", lwd = 2, lty = 2)
segments(47, error_avg$nopool_avg[4], 
         60, error_avg$nopool_avg[4], 
         col = rangi2, lwd = 2)
segments(47, error_avg$partpool_avg[4], 
         60, error_avg$partpool_avg[4], 
         col = "black", lwd = 2, lty = 2)

# draw vertical dividers between tank densities
abline( v = 16.5, lwd = 0.5 )
abline( v = 32.5, lwd = 0.5 )
abline( v = 46.5, lwd = 0.5 )
text( 8, 0.6, 'Tiny ponds (5)')
text( 16 + 8, 0.6, 'Small ponds (10)')
text( 32 + 8, 0.6, 'Medium ponds (25)')
text( 46 + 8, 0.6, 'Large ponds (35)')
Error of no-pooling and partial pooling estimates, for the simulated tadpole ponds. The horizontal axis displays pond number. The vertical axis measures the absolute error in the predicted proportion of survivors, compared to the true value used in the simulation. The higher the point, the worse the estimate. No-pooling shown in blue. Partial pooling shown in black. The blue and dashed black lines show the average error for each kind of estimate, across each initial density of tadpoles (pond size). Smaller ponds porduce more error, but the partial pooling estimates are better on average, especially in smaller ponds.

圖 55.4: Error of no-pooling and partial pooling estimates, for the simulated tadpole ponds. The horizontal axis displays pond number. The vertical axis measures the absolute error in the predicted proportion of survivors, compared to the true value used in the simulation. The higher the point, the worse the estimate. No-pooling shown in blue. Partial pooling shown in black. The blue and dashed black lines show the average error for each kind of estimate, across each initial density of tadpoles (pond size). Smaller ponds porduce more error, but the partial pooling estimates are better on average, especially in smaller ponds.

從圖 55.4 中我們首先能夠直接一眼就觀察到的重要信息就是,這兩種方案,一個是完全不合併方案,一個是部分合併方案,無論是哪一種,其實對樣本量大的水池(圖中靠右側的水池)的存活概率估計誤差都比較低。這主要是因爲樣本量越多,估計得越精確。而樣本量較小的水池,途中靠左側的水池中,由於蝌蚪數量有限,即使是部分合併方案使用的隨機截距模型也給出比較大的誤差。其次,藍色線 (完全不合併方案) 幾乎總是在黑色虛線 (隨機截距模型,部分合併方案) 的上方,或者二者在大樣本時,會十分接近。當然隨機截距並不總是更加優越,只是在許許多多的計算中,從長遠來看 (in the long run) 隨機截距模型給出的結果誤差會平均地比較小。第三,藍色線和黑色虛線之間的差距在樣本量越小時,越明顯。也就是說,同樣因爲小樣本會造成結果有估計誤差,隨機截距模型給出的誤差要相對小一些。

那麼,从計算機的模擬計算結果中,我們學到了什麼?記得圖 55.1 中我們見過樣本量越小的池塘的模型結果更加靠近樣本均值的虛線,也就是縮水更加嚴重。但是從計算機模擬的結果來看的話,樣本量越小的池塘的存活概率估計結果是隨機截距模型能給出更加小的誤差估計。這兩個現象並不是偶然發生的。樣本量小的水池,傾向於發生模型的過度擬合 overfitting。由於樣本量較小的水池蘊含的信息量較少,所以它們的模型估計結果更加容易受到樣本均值的影響。也就是被其他樣本量更多的水池的數據的影響。當一個個的水池本身各自的樣本量都相對較多時,你可能會認爲隨機截距或者叫多層回歸模型能給出的估計結果優化就很有限。事實上即便是每個數據層級本身的樣本量也比較大的情況下,使用多層回歸模型來計算也沒有任何壞處。大樣本量的一些層級的估計結果有可能有助於改善較小樣本量層級的結果的預測以及參數的估計結果。所以,平均地看,其實始終應該使用隨機效應模型,也就是部分合併方案的策略,因爲它總是能提供較優的結果估計,而且能夠從數據本身學習獲得應該使用的超參數等用於調節 (regularization) 模型的估計和運行。

下面的代碼有助於我們重複使用已經運行過的模型,減少計算機重複運算的壓力。當你想要重複上述計算機模擬過程的時候,可能會希望讓模型運行其他的模擬數據,採集新的事後分佈樣本:

a <- 1.5
sigma <- 1.5
nponds <- 60
Ni <- as.integer( rep(c(5, 10, 25, 35), each = 15 ))
set.seed(12345)
a_pond <- rnorm( nponds, mean = a, sd = sigma)
dsim <- data.frame( pond = 1:nponds, 
                    Ni = Ni, 
                    true_a = a_pond)

dsim$Si <- rbinom( nponds, prob = inv_logit( dsim$true_a ), size = dsim$Ni )
dsim$p_nopool <-  dsim$Si / dsim$Ni
newdat <-  list(Si = dsim$Si, 
                Ni = dsim$Ni, 
                pond = 1:nponds)

m13.3new <- stan( fit = m13.30@stanfit, 
                  data = newdat, 
                  chains = 4 )
saveRDS(m13.3new, "../Stanfits/m13_3new.rds")

一旦你的計算機已經運行好了一個模型,m13.30,那麼假如只需要修改模型的樣本數據,模型結構不需要改變的話,使用上述的方法會大大提升新模型的運行速度,並且保存結果在 m13.3new 裏面。然後你只需要使用類似的方法重新繪製新數據給出的新結果,而不需要每次再重頭運行模型本身。需要重複利用的模型運算結果已經存儲在了每個 stan 模型中的 stanfit 部分。只要給它新的相同變量的數據框,它就能迅速給出新的事後概率分佈結果。這是非常有用的技巧。

m13.3new <- readRDS("../Stanfits/m13_3new.rds")
post <- extract.samples( m13.3new )
dsim$p_partpool <- apply( inv_logit(post$a_pond), 2, mean)

dsim <- dsim %>% 
  mutate(p_true = inv_logit(true_a),
         nopool_error = abs(p_nopool - p_true), 
         partpool_erro = abs(p_partpool - p_true))

# or similarly you can use basic R command
# nopool_erro <- abs( dsim$p_nopool - dsim$p_true )
# partpool_erro <- abs( dsim$p_partpool - dsim$p_true )


plot( 1:60, dsim$nopool_error,
      xlab = "Pond", 
      ylab = "absolute error", bty = "n",
      col = rangi2, pch = 16, 
      ylim = c(0, 0.6))
points( 1:60, dsim$partpool_erro )


# mark posterior mean probability across tanks 
error_avg <- dsim %>%
  group_by(Ni) %>% 
  summarise(nopool_avg = mean(nopool_error), 
            partpool_avg = mean(partpool_erro))
segments(1, error_avg$nopool_avg[1], 
         16, error_avg$nopool_avg[1], 
         col = rangi2, lwd = 2)
segments(1, error_avg$partpool_avg[1], 
         16, error_avg$partpool_avg[1], 
         col = "black", lwd = 2, lty = 2)
segments(17, error_avg$nopool_avg[2], 
         32, error_avg$nopool_avg[2], 
         col = rangi2, lwd = 2)
segments(17, error_avg$partpool_avg[2], 
         32, error_avg$partpool_avg[2], 
         col = "black", lwd = 2, lty = 2)
segments(33, error_avg$nopool_avg[3], 
         46, error_avg$nopool_avg[3], 
         col = rangi2, lwd = 2)
segments(33, error_avg$partpool_avg[3], 
         46, error_avg$partpool_avg[3], 
         col = "black", lwd = 2, lty = 2)
segments(47, error_avg$nopool_avg[4], 
         60, error_avg$nopool_avg[4], 
         col = rangi2, lwd = 2)
segments(47, error_avg$partpool_avg[4], 
         60, error_avg$partpool_avg[4], 
         col = "black", lwd = 2, lty = 2)

# draw vertical dividers between tank densities
abline( v = 16.5, lwd = 0.5 )
abline( v = 32.5, lwd = 0.5 )
abline( v = 46.5, lwd = 0.5 )
text( 8, 0.6, 'Tiny ponds (5)')
text( 16 + 8, 0.6, 'Small ponds (10)')
text( 32 + 8, 0.6, 'Medium ponds (25)')
text( 46 + 8, 0.6, 'Large ponds (35)')
New data was fed to m13.3 model and generate new posterior estimations.

圖 55.5: New data was fed to m13.3 model and generate new posterior estimations.

55.3 使用多於一個類別作爲多層回歸的隨機變量 more than one type of cluster

我們當然可以在同一個模型中加入更多的 \((>1)\) 分層變量。例如我們在 Chapter 53.1.1 看到的黑猩猩社會學數據 data(chimpanzees)

data("chimpanzees")
d <- chimpanzees
str(d)
## 'data.frame':    504 obs. of  8 variables:
##  $ actor       : int  1 1 1 1 1 1 1 1 1 1 ...
##  $ recipient   : int  NA NA NA NA NA NA NA NA NA NA ...
##  $ condition   : int  0 0 0 0 0 0 0 0 0 0 ...
##  $ block       : int  1 1 1 1 1 1 2 2 2 2 ...
##  $ trial       : int  2 4 6 8 10 12 14 16 18 20 ...
##  $ prosoc_left : int  0 0 1 0 1 1 1 1 0 0 ...
##  $ chose_prosoc: int  1 0 0 1 1 1 0 0 1 1 ...
##  $ pulled_left : int  0 1 0 0 1 1 0 0 0 0 ...

這個數據裏,pulled_left 是從屬於每一頭黑猩猩個體的 (within a cluster of pulls belonging to an individual chimpanzee)。同時呢,這些拉動左側槓桿的行爲其實又是從屬於一個個實驗設計的 block 下的。這些 block 實際標記的是同一天進行的實驗。於是這裏出現了每個觀察數據的結果變量 pulled_left 既從屬於實驗對象 – 黑猩猩個體 (1 to 7),也從屬於實驗 block (1 to 6) 的現象。所以給黑猩猩個體和實驗 block 同時設置隨機截距也是沒有問題的。這裏我們利用這個特殊的數據來嘗試設計並運行含有兩個隨機截距結構的模型。這樣我們可以使用數據本身蘊含的信息充分學習應有的超參數用於我們已知的部分合併策略 partial pooling,從而提升各個參數的估計結果和效率,並且同時獲得不同的黑猩猩之間的方差,和不同的實驗 block 之間的方差。

55.3.1 黑猩猩數據的多層回歸模型 multilevel chimpanzees

我們可以直接利用 Chapter 53.1.1 一開始設定好的模型,增加 block 的隨機截距:

\[ \begin{aligned} L_i & \sim \text{Binomial}(1, p_i) \\ \text{logit}(p_i) & = \alpha_{\text{ACTOR}[i]} + \color{green}{\gamma_{\text{BLOCK}[i]}} + \beta_{\text{TREATMENT}[i]}\\ \beta_j & \sim \text{Normal}(0, 0.5) \;\;\;\; \text{for } j = 1,\dots, 4\\ \alpha_j & \sim \text{Normal}(\bar{\alpha}, \sigma_\alpha) \;\;\;\; \text{for } j = 1, \dots, 7 \\ \color{green}{\gamma_j } &\; \color{green}{\sim \text{Normal}(0, \sigma_\gamma)\;\;\;\;\; \text{for } j = 1, \dots, 6} \\ \bar{\alpha} & \sim \text{Normal}(0, 1.5) \\ \sigma_\alpha & \sim \text{Exponential} (1) \\ \color{green}{\sigma_\gamma} & \;\color{green}{ \sim \text{Exponential}(1)} \end{aligned} \]

從模型結構上,我們給不同的分層變量設置了自己的參數向量,對於每隻黑猩猩 actor,我們設定的參數向量是 \(\alpha\),它有7個元素,長度是 7,因爲一共有七隻黑猩猩;實驗的 block 有 6 個,所以它的參數向量長度是 6。這兩個分層變量需要有自己的方差(標準差)參數,也就是 \(\sigma_\alpha, \sigma_\gamma\)。要注意的一點是只能給一個總體平均值 \(\bar{\alpha}\) 給兩個隨機截距。下面的代碼就可以運行上述模型:

d <- d %>% 
  mutate(treatment = 1 + prosoc_left + 2*condition)
table(d$treatment)
## 
##   1   2   3   4 
## 126 126 126 126
dat_list <- list(
  pulled_left = d$pulled_left, 
  actor  = d$actor, 
  block_id = d$block, 
  treatment = as.integer(d$treatment)
)
set.seed(13) 

m13.4 <- ulam(
  alist(
    pulled_left  ~ dbinom( 1, p ) , 
    logit(p) <- a[actor] + g[block_id] + b[treatment], 
    b[treatment] ~ dnorm( 0, 0.5 ), 
    ## adaptive priors
    a[actor] ~ dnorm( a_bar,  sigma_a ), 
    g[block_id] ~ dnorm( 0, sigma_g ), 
    ## hyper-priors 
    a_bar ~ dnorm( 0, 1.5 ), 
    sigma_a ~ dexp(1), 
    sigma_g ~ dexp(1)
  ), data = dat_list, chains = 4, cores = 4, log_lik = TRUE
)
saveRDS(m13.4, "../Stanfits/m13_4.rds")
m13.4 <- readRDS("../Stanfits/m13_4.rds")
precis(m13.4, depth = 2)
##                 mean         sd         5.5%        94.5%     n_eff     Rhat4
## b[1]    -0.127619719 0.30371293 -0.612525116  0.351831861 403.94596 1.0109754
## b[2]     0.408238854 0.30232189 -0.053260952  0.901320472 358.18837 1.0087428
## b[3]    -0.472686499 0.29977648 -0.933074984  0.020725297 325.50478 1.0134452
## b[4]     0.287140550 0.30400142 -0.196915752  0.783449534 314.64426 1.0108916
## a[1]    -0.370040842 0.38111969 -0.956288690  0.250165297 359.95580 1.0120710
## a[2]     4.722208260 1.32508814  3.099312312  6.970585847 441.65568 1.0040650
## a[3]    -0.671017597 0.38087108 -1.286621692 -0.060970107 393.42474 1.0081857
## a[4]    -0.676761827 0.37681820 -1.281970759 -0.079834071 379.66040 1.0108697
## a[5]    -0.362885676 0.37315672 -0.958119173  0.248182788 388.40275 1.0096311
## a[6]     0.572373093 0.37008050 -0.019418355  1.173132021 413.40776 1.0063314
## a[7]     2.109826748 0.45840001  1.380573715  2.851828350 468.89658 1.0065407
## g[1]    -0.156800879 0.22875942 -0.566965710  0.094427404 321.47070 1.0083836
## g[2]     0.038277620 0.19474280 -0.222321524  0.363343586 947.73530 1.0013975
## g[3]     0.055900123 0.19581427 -0.188554858  0.364005505 764.31955 1.0000968
## g[4]     0.011329983 0.17640463 -0.245645460  0.291591789 900.16676 1.0016050
## g[5]    -0.033451700 0.19154418 -0.343860156  0.234539865 821.95378 1.0082556
## g[6]     0.110218541 0.20252666 -0.129571122  0.438377440 760.17843 1.0023356
## a_bar    0.633025847 0.73797175 -0.516031871  1.785298810 700.52766 1.0041607
## sigma_a  2.032605929 0.64825884  1.220010000  3.139888818 570.05539 1.0006796
## sigma_g  0.209456474 0.17425778  0.032948405  0.516423626 252.48113 1.0167476

首先,我們從 n_eff 可以看出各個參數的有效樣本量差別其實較大。這樣的現象在結構複雜的模型進行事後樣本採樣的過程中其實很常見。這可能會有許多不同的原因,其中之一是模型中可能有一個或者幾個在樣本採集時花了較多的時間在某個邊界值附近不停地採集樣本。這裏很顯然就是 sigma_g,它花了很多時間在它的起始值 0 附近不停地採集樣本,它的 Rhat 值也顯然大於 1。這些都是採樣效率低下的信號。

precis_plot( precis(m13.4, depth = 2) )
Posterior means and 89% compatibility intervals for m13.4. The greater variation across actors than blocks can be seen immediately in the a and g distributions

圖 55.6: Posterior means and 89% compatibility intervals for m13.4. The greater variation across actors than blocks can be seen immediately in the a and g distributions

其次,觀察 sigma_asigma_g 會很容易就發現不同黑猩猩之間的變化顯然比不同天進行實驗的變化要顯著的多。這一現象可以用圖 55.7 展示得更加清楚。

post <- extract.samples( m13.4 )
rethinking::dens( post$sigma_a ,
                  xlim = c(0, 4), 
                  ylim = c(0, 3.8), 
                  col = rangi2, 
                  bty = "n", 
                  lwd = 2, 
                  xlab = "standard deviation", 
                  ylab = "Density")
rethinking::dens( post$sigma_g , add =  TRUE, 
                  lwd = 2)

text( 0.8, 02.5, 'block')
text( 3, 0.5, 'actor', col = rangi2)
Posterior distributions of the standard deviations of varying intercepts by actor (blue), and block (black).

圖 55.7: Posterior distributions of the standard deviations of varying intercepts by actor (blue), and block (black).

這也就是說增加不同實驗 block 的隨機截距並沒有讓模型增加的過度擬合的風險。我們來比較一下只有一個黑猩猩隨機截距時的模型和 m13.4 之間的模型信息差別:

set.seed(14)
m13.5 <- ulam(
  alist(
    pulled_left ~ dbinom( 1, p ), 
    logit(p) <- a[actor] + b[treatment] , 
    b[treatment] ~ dnorm( 0, 0.5 ), 
    a[actor] ~ dnorm( a_bar, sigma_a ),
    a_bar ~ dnorm( 0, 1.5 ), 
    sigma_a ~ dexp(1)
  ), data = dat_list, chains = 4, cores = 4, log_lik = TRUE
)

saveRDS(m13.5, "../Stanfits/m13_5.rds")
m13.5 <- readRDS("../Stanfits/m13_5.rds")
compare(m13.4, m13.5)
##            WAIC        SE      dWAIC       dSE      pWAIC     weight
## m13.5 531.37616 19.126297 0.00000000        NA  8.5921361 0.60655949
## m13.4 532.24191 19.421569 0.86574585 1.6294066 10.6837879 0.39344051

m13.4m13.5 兩個模型之間的比較結果來看,即便 m13.4 中多增加了 7 個未知參數,但是 pWAIC 的比較,也就是實際有效參數個數之間的差只有 2。這主要是因爲 block 的事後分佈的方差其實十分接近 0,所以表示這個 block 部分的隨機截距部分增加的參數的實際結果都接近 0。我們的多層回歸模型雖然可以做到增加實驗 block 的隨機截距,但是增加這個隨機截距對模型並沒有顯著的改善,可以說沒有太多幫助。

55.4 分散轉換與非中心型先驗概率 divergent transition and non-centered priors

使用並運行多層回歸模型時,Stan 經常可能送給你一個莫名其妙的警告,類似:

There were 15 divergent transitions after warmup.

具體原因可能有很多,有主要兩種方式來克服這個警告。第一種是使用更多的 burn-in 或者叫做 warm-up,並且調整 Stan 裏的設置採樣跳躍幅度的變量 adapt_delta,它默認是 0.8,把它改成0.9以上(只能是小於1的數值)的數字之後跳躍採集樣本的幅度會縮小一些,從而改善事後樣本採集的代表性,一定程度上可以避免看見上述警告。但是有些多層回歸模型不論你怎麼調整這個跳躍幅度,增加採樣的 burn-in 過程,它始終都無法給出合適的事後樣本分佈。這時候需要使用的技巧是重新改寫你的模型。很多統計學模型你可以轉換思路用別的方式來表達在數學上涵義相同的模型。這個方法又被叫做再參數化 (reparameterize)。

下面是兩個簡單的實例。

55.4.1 魔鬼的漏斗 the devil’s funnel

我們不需要用複雜的模型就能體驗到 Stan 給出的分散轉換 divergent transition 警告。假如有兩個簡單的變量 \(v, x\) 他們之間的關係是:

\[ \begin{aligned} v & \sim \text{Normal}(0, 3) \\ x & \sim \text{Normal}(0, \exp(v)) \end{aligned} \]

沒有特別的數據,只有這樣兩個互相有聯繫的聯合分佈需要我們嘗試去採集樣本。這是典型的多層回歸結構模型,因爲變量 \(x\) 的方差由另一個變量 \(v\) 來決定。這個模型的運行程序如下:

m13.7 <- ulam(
  alist(
    v ~ normal(0, 3), 
    x ~ normal(0, exp(v))
  ), data = list(N = 1), chains = 4
)
saveRDS(m13.7, "../Stanfits/m13_7.rds")

你會很顯然看見一連串的警告,叫你去看這個看那個求助啥的:

There were 78 divergent transitions after warmup. See
http://mc-stan.org/misc/warnings.html#divergent-transitions-after-warmup
to find out why this is a problem and how to eliminate them.There were 2 transitions after warmup that exceeded the maximum treedepth. Increase max_treedepth above 10. See
http://mc-stan.org/misc/warnings.html#maximum-treedepth-exceededThere were 2 chains where the estimated Bayesian Fraction of Missing Information was low. See
http://mc-stan.org/misc/warnings.html#bfmi-lowExamine the pairs() plot to diagnose sampling problems
The largest R-hat is 1.15, indicating chains have not mixed.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#r-hatBulk Effective Samples Size (ESS) is too low, indicating posterior means and medians may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#bulk-essTail Effective Samples Size (ESS) is too low, indicating posterior variances and tail quantiles may be unreliable.
Running the chains for more iterations may help. See
http://mc-stan.org/misc/warnings.html#tail-ess

這個只有兩個參數需要估計的模型運行給出的事後概率分佈也十分地糟糕,上面的警告中給出了相當多的分散轉換 (divergent transitions) ,下面的模型運行結果總結也給出了特別差勁的 n_eff, Rhat

m13.7 <- readRDS("../Stanfits/m13_7.rds")
precis(m13.7)
##          mean          sd         5.5%      94.5%     n_eff     Rhat4
## v   1.7154855   2.3068647   -1.4054807  5.9914017 17.287243 1.1557813
## x -41.8446286 311.2534143 -271.8156635 47.4611106 92.111681 1.0141386

看一下它可憐的採樣軌跡圖 trace plot :

traceplot(m13.7)
## [1] 1000
## [1] 1
## [1] 1000
`traceplot(m13.7)`

圖 55.8: traceplot(m13.7)

我們可以簡單地通過修改模型的構建模式來克服這個問題。因爲變量 \(x\) 的方差取決於 \(v\)

\[ x \sim \text{Normal}(0, \exp(v)) \]

變量 \(v\) 決定了 \(x\) 的方差大小,上面的這種模型結構被叫做參數中心化 (centered parameterization),其涵義就是一個參數的分佈由另一個參數或者多個參數來決定。參數中心化之外的另一種選擇是參數非中心化 (non-centered parameterization)。這個非中心化就是把參數之間的依賴關係保留,但是在寫成模型的時候儘量避免在指定分佈的那行中加入兩個參數。例如可以把 m13.7 的表達式改寫成:

\[ \begin{aligned} v & \sim \text{Normal}(0, 3) \\ z & \sim \text{Normal}(0, 1) \\ x & = z \exp(v) \end{aligned} \]

很多人可能一開始不理解爲什麼要這樣寫。但是仔細想想應該不難理解,這其實是我們平時在把觀察值標準化的一個逆向過程。我們在把某個變量標準化的過程是怎樣的?通常是把它減去自己的平均值,然後除以自己的標準差。新產生的變量就是一個均值爲0,標準差是1的標準正(常)態分佈。也就是說,上面的表達式裏,我們通過 \(z\),一個標準正(常)態分佈變量,把 \(x\)\(v\) 之間的關係串聯起來。\(x\) 本身的均值是零,它除以自己的標準差 \(\frac{x}{\exp(v)}\) 就成爲了一個標準正(常)態的變量 \(z\)。經過這一番等價轉換之後模型變得可以順利在 Stan 裏被運行和採樣了。

m13.7nc <- ulam(
  alist(
    v ~ normal(0, 3), 
    z ~ normal(0, 1),
    gq> real[1]: x <<- z*exp(v)
  ), data = list(N = 1), chains = 4
)
saveRDS(m13.7nc, "../Stanfits/m13_7nc.rds")

整個世界恢復了安靜。你看模型運行的結果也是正常的了:

m13.7nc <- readRDS("../Stanfits/m13_7nc.rds")
precis(m13.7nc)
##            mean           sd        5.5%      94.5%     n_eff     Rhat4
## v -0.0287545661   2.95017093  -4.6421334  4.6942797 1267.2926 1.0028912
## z -0.0093817719   0.97405443  -1.5831453  1.4929941 1459.2957 1.0004412
## x  2.5224882198 197.04808870 -22.7731384 20.7557186 1787.5479 1.0006452
traceplot(m13.7nc)
## [1] 1000
## [1] 1
## [1] 1000
`traceplot(m13.7nc)`

圖 55.9: traceplot(m13.7nc)

如果我們把此時採樣成功的 \(x\)\(v\) 之間繪製散點圖,你就會直觀的看見這個像魔鬼一樣的漏斗的真實形狀:

dat.sam <- extract.samples(m13.7nc)
plot(dat.sam$x, 
     dat.sam$v, 
     bty = "n", 
     xlab = "x",
     ylab = "v")
The devil's funnel.

圖 55.10: The devil’s funnel.

我們成功地對這樣的近乎畸形的變量實施了轉換數據之後的事後樣本採集。

55.4.2 參數非中心化的黑猩猩數據

接下來我們來試圖解決黑猩猩數據中使用多層回歸模型時出現的分散轉移 divergent transition 問題。當時我們的 m13.4 試圖給 block 增加隨機效應,當時在設定參數的先驗概率分佈時,設定了兩個參數在相同的行裏,他們也是導致模型運行報警的原因。現在我們可以來試着解決它。

\[ \begin{aligned} L_i & \sim \text{Binomial}(1, p_i) \\ \text{logit}(p_i) & = \alpha_{\text{ACTOR}[i]} + \color{green}{\gamma_{\text{BLOCK}[i]}} + \beta_{\text{TREATMENT}[i]}\\ \beta_j & \sim \text{Normal}(0, 0.5) \;\;\;\; \text{for } j = 1,\dots, 4\\ \alpha_j & \sim \text{Normal}(\bar{\alpha}, \sigma_\alpha) \;\;\;\; \text{for } j = 1, \dots, 7 \\ \color{green}{\gamma_j } &\; \color{green}{\sim \text{Normal}(0, \sigma_\gamma)\;\;\;\;\; \text{for } j = 1, \dots, 6} \\ \bar{\alpha} & \sim \text{Normal}(0, 1.5) \\ \sigma_\alpha & \sim \text{Exponential} (1) \\ \color{green}{\sigma_\gamma} & \;\color{green}{ \sim \text{Exponential}(1)} \end{aligned} \]

在對模型進行重新參數化之前,我們可以先試着在 Stan 內部嘗試調整 adapt_delta ,它原本默認的大小是 0.95:

set.seed(2020)
m13.4b <- ulam(m13.4, chains = 4, 
               cores = 4, 
               control = list(adapt_delta = 0.99))
saveRDS(m13.4b, "../Stanfits/m13_4b.rds")
m13.4b <- readRDS("../Stanfits/m13_4b.rds")
divergent(m13.4b)
## [1] 5

可見修改這個 adapt_delta 也沒有辦法提升太多,它依然在報錯。當然偶爾也能真的解決問題,實在是在看你的運氣。而且很多時候,即使它不再通過電腦系統警告,它實際採集的事後樣本也是十分低效的。你可以觀察 precis(m13.4b) 給出的 n_eff,也就是有效樣本量其實很多都還是小於500的。實際使用4条獨立採集鏈每條500個獨立樣本的總樣本量應該在2000左右。

precis(m13.4b, depth = 2)
##                  mean         sd         5.5%        94.5%      n_eff     Rhat4
## b[1]    -0.1572102900 0.30678784 -0.639271495  0.343564769  402.50898 1.0133566
## b[2]     0.3760131051 0.29844899 -0.097009535  0.852948659  438.79219 1.0067294
## b[3]    -0.4949120780 0.30808988 -1.001918025 -0.017546887  418.26237 1.0100785
## b[4]     0.2581843386 0.30676503 -0.246029052  0.754586625  414.13612 1.0107870
## a[1]    -0.3475363401 0.36432862 -0.926144369  0.200724744  437.14385 1.0091803
## a[2]     4.5946727777 1.21993398  3.006558784  6.800638368  619.43260 1.0103146
## a[3]    -0.6335415787 0.36505148 -1.198982995 -0.042148804  389.82795 1.0122092
## a[4]    -0.6393672456 0.36432523 -1.227955826 -0.078705404  464.07412 1.0089550
## a[5]    -0.3338642974 0.37663905 -0.946543034  0.275166346  473.68551 1.0118031
## a[6]     0.5991703820 0.37498622  0.026426968  1.186117279  457.28046 1.0091865
## a[7]     2.1471699790 0.45691772  1.450735749  2.914264702  533.10034 1.0133078
## g[1]    -0.1428730829 0.20044449 -0.525017505  0.070405136  471.73005 1.0111776
## g[2]     0.0387465222 0.15132335 -0.168327269  0.306505379  682.87717 1.0038622
## g[3]     0.0505496597 0.15814008 -0.155178555  0.328932444  767.76580 1.0058225
## g[4]     0.0089953387 0.15486352 -0.230211915  0.257745757 1082.25766 1.0003697
## g[5]    -0.0225651857 0.16089717 -0.292705183  0.227528171 1291.59183 1.0018525
## g[6]     0.0996398014 0.18231717 -0.116919531  0.447152911  506.41514 1.0121506
## a_bar    0.5884998301 0.74455413 -0.641517611  1.727670676  572.43485 1.0061693
## sigma_a  1.9687830785 0.63396246  1.175349730  3.176123372  897.43947 1.0017410
## sigma_g  0.1856985870 0.16517592  0.012759445  0.499172450  233.36486 1.0295059

如果通過修改參數的形式使之去中心化,則能夠大大改善模型的運行。這個模型裏需要修改參數化的主要是這兩行:

\[ \begin{aligned} \alpha_j & \sim \text{Normal}(\bar\alpha, \sigma_\alpha) & \text{[intercepts for actors]}\\ \gamma_j & \sim \text{Normal}(0, \sigma_\gamma)& \text{[intercepts for blocks]} \end{aligned} \]

這裏面其實有三個“中心化”的參數:\(\bar\alpha, \sigma_\alpha, \sigma_\gamma\)。使用類似 m13.7nc 的方法,我們需要爲他們設定標準化的替代參數:

\[ \begin{aligned} L_i & \sim \text{Binomial}(1, p_i) \\ \text{logit}(p_i) & = \color{lightblue}{\underbrace{\bar\alpha + z_{\text{ACTOR[i]}}\sigma_\alpha}_{\alpha_\text{ACTOR[i]}}} + \color{lightblue}{\underbrace{x_{\text{BLOCK}[i]}\sigma_\gamma}_{\gamma_\text{BLOCK[i]}}} + \beta_{\text{TREATMENT}[i]} \\ \beta_j & \sim \text{Normal}(0, 0.5), \text{ for } j = 1,\dots,4 \\ \color{lightblue}{z_j} & \color{lightblue}{\; \sim \text{Normal}(0,1)} & \text{[Standardized actor intercepts]} \\ \color{lightblue}{x_j} & \color{lightblue}{\; \sim \text{Normal}(0,1)} & \text{[Standardized block intercepts]} \\ \bar{\alpha} & \sim \text{Normal}(0, 1.5) \\ \sigma_\alpha & \sim \text{Exponential}(1) \\ \sigma_\gamma & \sim \text{Exponential}(1) \end{aligned} \]

不難發現經過修改厚的模型中向量 \(z\) 提供了標準化的 actor 隨機截距,\(x\) 提供了標準化的 block 隨機截距。每頭大猩猩 actor 的隨機截距實際被定義爲:

\[ \alpha_j = \bar\alpha + z_j \sigma_\alpha \]

每個實驗區塊 block 的隨機截距被定義爲:

\[ \gamma_j = x_j\sigma_\gamma \]

現在我們來運行這個被重新改寫過的 m13.4 模型:

set.seed(13)
m13.4nc <- ulam(
  alist(
    pulled_left ~ dbinom(1, p), 
    logit(p) <-  a_bar + z[actor]*sigma_a +    # actor intercepts
                x[block_id]*sigma_g +        # block intercepts
                b[treatment] ,
    b[treatment] ~ dnorm( 0, 0.5 ), 
    z[actor] ~ dnorm( 0, 1 ),
    x[block_id] ~ dnorm( 0, 1 ), 
    a_bar ~ dnorm( 0, 1.5 ), 
    sigma_a ~ dexp(1), 
    sigma_g ~ dexp(1), 
    gq> vector[actor]: a <<- a_bar + z*sigma_a, 
    gq> vector[block_id]: g<<- x*sigma_g
  ), data = dat_list, chains = 4, cores = 4
)

saveRDS(m13.4nc, "../Stanfits/m13_4nc.rds")

m13.4ncn_eff 顯然比 m13.4 改善很多,而且也沒有報錯:

m13.4nc <- readRDS("../Stanfits/m13_4nc.rds")
precis(m13.4nc, depth = 2)
##                  mean         sd          5.5%         94.5%      n_eff      Rhat4
## b[1]    -0.1373449475 0.30100329 -0.6152752117  0.3434728833 1109.20149 1.00595881
## b[2]     0.3888674543 0.31027407 -0.0889388119  0.8914549736 1126.93533 1.00365251
## b[3]    -0.4870408411 0.30935145 -0.9839392550  0.0070806699 1149.30896 1.00403385
## b[4]     0.2675311406 0.29608590 -0.2053236335  0.7265043819 1139.59722 1.00422956
## z[1]    -0.5139834523 0.39675770 -1.1678480395  0.0977745452  570.20117 1.01007561
## z[2]     2.1279913179 0.64805501  1.1422861675  3.2224402595  834.76934 1.00325956
## z[3]    -0.6837229179 0.41584942 -1.3577540963 -0.0245178170  569.57612 1.00789383
## z[4]    -0.6869211727 0.40944901 -1.3445954812 -0.0236560772  568.85572 1.00689660
## z[5]    -0.5172173982 0.39191673 -1.1403129550  0.1034563609  590.52189 1.00843504
## z[6]    -0.0016439012 0.36846295 -0.5776370019  0.5809683677  556.94923 1.00899136
## z[7]     0.8286856465 0.45469306  0.1351884604  1.5921906220  577.79040 1.00723877
## x[1]    -0.6899484779 0.92017834 -2.1405682214  0.7927252571 1879.62650 1.00020587
## x[2]     0.1683882817 0.82967260 -1.2074554330  1.4459472568 1988.35621 0.99902014
## x[3]     0.1883825661 0.88729551 -1.3445548516  1.5454719722 1599.57387 0.99934747
## x[4]     0.0736859911 0.87573225 -1.3284010857  1.5004526253 1872.59855 0.99969896
## x[5]    -0.1579854442 0.86476793 -1.5691453644  1.1965664548 1965.01018 1.00022145
## x[6]     0.4404814966 0.83099134 -0.9035996548  1.7467719731 2034.08888 1.00044538
## a_bar    0.5966686669 0.72455365 -0.5512172079  1.7547745106  457.50619 1.01001900
## sigma_a  1.9868039280 0.62475020  1.1935664496  3.0806919070  548.75222 1.01023381
## sigma_g  0.2055282086 0.16780483  0.0170035585  0.5116821942 1035.04896 1.00196218
## g[1]    -0.1659934915 0.21774107 -0.5702435381  0.0761933321 1360.43058 1.00156195
## g[2]     0.0426012678 0.17846856 -0.2168633003  0.3449450507 2239.18753 1.00002541
## g[3]     0.0505519252 0.17811492 -0.1941666909  0.3582976624 1811.46328 0.99948636
## g[4]     0.0143612958 0.17667352 -0.2641923950  0.2970075014 1916.45293 0.99928424
## g[5]    -0.0349256905 0.17496328 -0.3471585015  0.2205464443 1695.10486 1.00045918
## g[6]     0.1072076540 0.19605102 -0.1279119869  0.4622580054 1506.23216 1.00111282
## a[1]    -0.3404283451 0.36533335 -0.9132766082  0.2629761164 1281.86213 1.00166555
## a[2]     4.6517792446 1.26372873  3.0195252176  6.7913091202 1291.08110 1.00269334
## a[3]    -0.6469904859 0.36686590 -1.2332357909 -0.0507023244 1220.56241 1.00349456
## a[4]    -0.6567906671 0.37266026 -1.2479210814 -0.0427601862 1359.97407 1.00534633
## a[5]    -0.3469747466 0.37267807 -0.9282008426  0.2647662951 1230.41847 1.00636946
## a[6]     0.5924776232 0.37474838  0.0034196068  1.2016656784 1213.69242 1.00800165
## a[7]     2.1128718972 0.47064296  1.3943355597  2.8918979665 1700.68726 1.00210149

用圖形來比較 m13.4m13.4nc 二者之間的 n_eff 更加直觀:

precis_c <- precis( m13.4, depth = 2 )
precis_nc <- precis( m13.4nc, depth = 2 )
pars <- c( paste("a[", 1:7, "]", sep = ""), 
           paste("g[", 1:6, "]", sep = ""), 
           paste("b[", 1:4, "]", sep = ""), 
           "a_bar", "sigma_a", "sigma_g")

neff_table <- cbind( precis_c[pars, "n_eff"], 
                     precis_nc[pars, "n_eff"])

plot( neff_table, 
      xlim = range(neff_table), 
      ylim = range(neff_table), 
      xlab = "n_eff (centered)", 
      ylab = "n_eff (non-centered)", 
      lwd = 2, 
      bty = "n")
abline(a = 0, b = 1, lty = 2)
Comparing the centered (horizontal) and non-centered (vertical) parameerizations of the multilevel chimpanzees model, m13.4. Each point is a parameter. All but two parameters lie above the diagonal, indicating better sampling for the non-centered parameterization.

圖 55.11: Comparing the centered (horizontal) and non-centered (vertical) parameerizations of the multilevel chimpanzees model, m13.4. Each point is a parameter. All but two parameters lie above the diagonal, indicating better sampling for the non-centered parameterization.