A Bayesian Stan Implementation of the Poisson Regression Model

Analysis goal, the data, and the reason for choosing a Poisson regression model

Here we use the same data used earlier for fitting the Bayesian logistic regression model to demonstrate how to run a Bayesian Poisson regression model.

d <- read.table("https://raw.githubusercontent.com/winterwang/RStanBook/master/chap05/input/data-attendance-2.txt", sep = ",", header = T)
head(d)
##   PersonID A Score  M  Y
## 1        1 0    69 43 38
## 2        2 1   145 56 40
## 3        3 0   125 32 24
## 4        4 1    86 45 33
## 5        5 1   158 33 23
## 6        6 0   133 61 60

Here,

  • PersonID: the student's ID number;
  • A, Score: the two predictor variables used to predict attendance; A indicates whether the student likes part-time work, and Score measures how much the student likes studying itself (out of 200);
  • M: the total number of class hours the student was supposed to attend over the past three months;
  • Y: the number of class hours the student actually attended over the past three months.

This time we want to use Poisson regression to answer the question: how much do A and Score affect the total number of class hours M? When we fitted the Bayesian logistic regression model earlier, the outcome variable was Y, the number of class hours actually attended; in this section we use M as the outcome instead. The total number of class hours is the result of the student's own course selection, which means the student's own attitudes (whether they like part-time work, whether they love studying) may themselves have influenced how many class hours they chose. The background assumption is that students who prefer part-time work may approach course selection passively and sign up for fewer class hours from the start. When the outcome variable is a non-negative count like the total number of selected class hours, a Poisson regression model is our first choice.
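Before committing to the Poisson family, it is worth confirming that M really behaves like a non-negative count. A minimal check in R, using the data frame d loaded above:

summary(d$M)            # range of the total class hours
all(d$M >= 0)           # expected TRUE: no negative values
all(d$M == round(d$M))  # expected TRUE: integer-valued, as a count should be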

Imagining the model's mechanism

If we used the multiple linear regression model introduced two sections ago, the distribution of the predicted outcome could extend into negative values, contradicting the fact that the total number of selected class hours is a non-negative count. We therefore need to restrict the linear combination of the predictors A and Score, \((b_1 + b_2A + b_3Score)\), to the non-negative range through a mathematical transformation. Writing \(\lambda\) for the mean total class hours, we assume that M follows a Poisson distribution with mean \(\lambda\). For details on the Poisson distribution, including the derivation of its expectation and variance, see the study notes; for the conventional, non-Bayesian Poisson regression model, see the Poisson regression chapter under generalized linear models in the study notes.

Anyone with even a passing familiarity with Poisson regression will naturally think of the standard link function that restricts the outcome to the non-negative range: \(\log(\lambda)\). Equivalently, in the Stan model it is more natural to write the linear predictor inside an exponential: \(\exp(b_1 + b_2A + b_3Score)\).
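As a tiny illustration of why the exp transformation is needed (the coefficient values here are made up, not estimates), the linear predictor itself can be negative, while exp() always returns a strictly positive, hence valid, Poisson mean:

b1 <- -1.2; b2 <- 0.5; b3 <- 2.0   # hypothetical coefficients, for illustration only
eta <- b1 + b2 * 1 + b3 * 0.3      # linear predictor: can be any real number (-0.1 here)
exp(eta)                           # ~0.905: strictly positive, a valid Poisson mean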

寫下數學模型表達式

\[ \begin{aligned} \lambda[n] & = \exp(b_1 + b_2A[n] + b_3Score[n]) & n = 1, \dots, N \\ M[n] & \sim \text{Poisson}(\lambda[n]) & n = 1, \dots, N \end{aligned} \]

Here,

  • \(N\) is the number of students in the data;
  • \(n\) is the index (subscript) of each student;
  • \(b_1, b_2, b_3\) are the parameters we are interested in.
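One way to confirm that this generative story is understood is to simulate from it. A hedged sketch in R, with made-up values for \(b_1, b_2, b_3\) (they are not the estimates obtained below):

set.seed(1234)
N <- 50
A <- rbinom(N, 1, 0.4)           # made-up distribution for the part-time-work indicator
Score <- runif(N, 0, 1)          # Score rescaled to [0, 1], as in the Stan data block below
b <- c(3.6, 0.26, 0.3)           # hypothetical coefficients
lambda <- exp(b[1] + b[2] * A + b[3] * Score)
M <- rpois(N, lambda)            # total class hours drawn from Poisson(lambda)
summary(M)                       # every simulated value is a non-negative integer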

Translating the mathematical model into Stan code

data {
  int N;                            // number of students
  int<lower=0, upper=1> A[N];       // likes part-time work (0/1)
  real<lower=0, upper=1> Score[N];  // liking for study, rescaled to [0, 1]
  int<lower=0> M[N];                // total class hours over three months (outcome)
}

parameters {
  real b[3];  // b[1]: intercept; b[2], b[3]: coefficients on A and Score
}

transformed parameters {
  real lambda[N];
  for (n in 1:N) {
    lambda[n] = exp(b[1] + b[2]*A[n] + b[3]*Score[n]);
  }
}

model {
  for (n in 1:N) {
    M[n] ~ poisson(lambda[n]); 
  }
}

generated quantities {
  int m_pred[N];
  for (n in 1:N) {
    m_pred[n] = poisson_rng(lambda[n]);  // posterior predictive draw for each student
  }
}

It is worth mentioning that Stan provides the poisson_log(x) distribution, which is equivalent to poisson(exp(x)). Besides being closer to the mathematical expression of Poisson regression that we are familiar with, it avoids the explicit exp computation, making the calculation numerically more stable. We can therefore rewrite the model above as follows (a numerical check of this equivalence appears right after the model):

data {
  int N; 
  int<lower=0, upper=1> A[N]; 
  real<lower=0, upper=1> Score[N]; 
  int<lower=0> M[N];
}

parameters {
  real b[3]; 
}

transformed parameters {
  real lambda[N];
  for (n in 1:N) {
    lambda[n] = b[1] + b[2]*A[n] + b[3]*Score[n];  // note: this is the log of the Poisson mean
  }
}

model {
  for (n in 1:N) {
    M[n] ~ poisson_log(lambda[n]); 
  }
}

generated quantities {
  int m_pred[N];
  for (n in 1:N) {
    m_pred[n] = poisson_log_rng(lambda[n]);  // poisson_log_rng takes the log-rate directly
  }
}
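The equivalence between poisson_log(x) and poisson(exp(x)) can be verified numerically outside Stan. A quick R sketch with dpois (the values of y and x are arbitrary):

y <- 12; x <- 2.5
dpois(y, lambda = exp(x), log = TRUE)  # log density under poisson(exp(x))
y * x - exp(x) - lfactorial(y)         # same value, computed directly on the log scale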

The R code to run this model is as follows:

library(rstan)
## Loading required package: StanHeaders
## Loading required package: ggplot2
## rstan (Version 2.19.2, GitRev: 2e1f913d3ca3)
## For execution on a local, multicore CPU with excess RAM we recommend calling
## options(mc.cores = parallel::detectCores()).
## To avoid recompilation of unchanged Stan programs, we recommend calling
## rstan_options(auto_write = TRUE)
data <- list(N=nrow(d), A=d$A, Score=d$Score/200, M=d$M)
# fit <- stan(file='model/model5-6.stan', data=data, seed=1234)
fit <- stan(file='stanfiles/model5-6b.stan', data=data, seed=1234, pars = c("b", "lambda"))
## 
## SAMPLING FOR MODEL 'model5-6b' NOW (CHAIN 1).
## Chain 1: 
## Chain 1: Gradient evaluation took 4.3e-05 seconds
## Chain 1: 1000 transitions using 10 leapfrog steps per transition would take 0.43 seconds.
## Chain 1: Adjust your expectations accordingly!
## Chain 1: 
## Chain 1: 
## Chain 1: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 1: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 1: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 1: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 1: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 1: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 1: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 1: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 1: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 1: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 1: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 1: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 1: 
## Chain 1:  Elapsed Time: 0.128123 seconds (Warm-up)
## Chain 1:                0.123794 seconds (Sampling)
## Chain 1:                0.251917 seconds (Total)
## Chain 1: 
## 
## SAMPLING FOR MODEL 'model5-6b' NOW (CHAIN 2).
## Chain 2: 
## Chain 2: Gradient evaluation took 8e-06 seconds
## Chain 2: 1000 transitions using 10 leapfrog steps per transition would take 0.08 seconds.
## Chain 2: Adjust your expectations accordingly!
## Chain 2: 
## Chain 2: 
## Chain 2: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 2: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 2: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 2: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 2: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 2: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 2: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 2: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 2: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 2: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 2: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 2: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 2: 
## Chain 2:  Elapsed Time: 0.121585 seconds (Warm-up)
## Chain 2:                0.130576 seconds (Sampling)
## Chain 2:                0.252161 seconds (Total)
## Chain 2: 
## 
## SAMPLING FOR MODEL 'model5-6b' NOW (CHAIN 3).
## Chain 3: 
## Chain 3: Gradient evaluation took 7e-06 seconds
## Chain 3: 1000 transitions using 10 leapfrog steps per transition would take 0.07 seconds.
## Chain 3: Adjust your expectations accordingly!
## Chain 3: 
## Chain 3: 
## Chain 3: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 3: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 3: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 3: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 3: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 3: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 3: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 3: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 3: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 3: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 3: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 3: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 3: 
## Chain 3:  Elapsed Time: 0.117803 seconds (Warm-up)
## Chain 3:                0.123413 seconds (Sampling)
## Chain 3:                0.241216 seconds (Total)
## Chain 3: 
## 
## SAMPLING FOR MODEL 'model5-6b' NOW (CHAIN 4).
## Chain 4: 
## Chain 4: Gradient evaluation took 7e-06 seconds
## Chain 4: 1000 transitions using 10 leapfrog steps per transition would take 0.07 seconds.
## Chain 4: Adjust your expectations accordingly!
## Chain 4: 
## Chain 4: 
## Chain 4: Iteration:    1 / 2000 [  0%]  (Warmup)
## Chain 4: Iteration:  200 / 2000 [ 10%]  (Warmup)
## Chain 4: Iteration:  400 / 2000 [ 20%]  (Warmup)
## Chain 4: Iteration:  600 / 2000 [ 30%]  (Warmup)
## Chain 4: Iteration:  800 / 2000 [ 40%]  (Warmup)
## Chain 4: Iteration: 1000 / 2000 [ 50%]  (Warmup)
## Chain 4: Iteration: 1001 / 2000 [ 50%]  (Sampling)
## Chain 4: Iteration: 1200 / 2000 [ 60%]  (Sampling)
## Chain 4: Iteration: 1400 / 2000 [ 70%]  (Sampling)
## Chain 4: Iteration: 1600 / 2000 [ 80%]  (Sampling)
## Chain 4: Iteration: 1800 / 2000 [ 90%]  (Sampling)
## Chain 4: Iteration: 2000 / 2000 [100%]  (Sampling)
## Chain 4: 
## Chain 4:  Elapsed Time: 0.112146 seconds (Warm-up)
## Chain 4:                0.136731 seconds (Sampling)
## Chain 4:                0.248877 seconds (Total)
## Chain 4:
fit
## Inference for Stan model: model5-6b.
## 4 chains, each with iter=2000; warmup=1000; thin=1; 
## post-warmup draws per chain=1000, total post-warmup draws=4000.
## 
##               mean se_mean   sd    2.5%     25%     50%     75%   97.5%
## b[1]          3.58    0.00 0.09    3.38    3.51    3.58    3.64    3.76
## b[2]          0.26    0.00 0.04    0.18    0.24    0.26    0.29    0.35
## b[3]          0.29    0.00 0.15    0.00    0.20    0.29    0.39    0.59
## lambda[1]     3.68    0.00 0.05    3.58    3.65    3.68    3.71    3.77
## lambda[2]     4.05    0.00 0.03    3.98    4.03    4.05    4.08    4.12
## lambda[3]     3.76    0.00 0.03    3.70    3.74    3.76    3.78    3.81
## lambda[4]     3.97    0.00 0.04    3.88    3.94    3.97    3.99    4.04
## lambda[5]     4.07    0.00 0.04    3.99    4.04    4.07    4.10    4.15
## lambda[6]     3.77    0.00 0.03    3.71    3.75    3.77    3.79    3.83
## lambda[7]     3.74    0.00 0.03    3.68    3.72    3.74    3.76    3.80
## lambda[8]     4.05    0.00 0.04    3.99    4.03    4.05    4.08    4.13
## lambda[9]     3.79    0.00 0.03    3.72    3.77    3.79    3.81    3.86
## lambda[10]    3.79    0.00 0.03    3.72    3.77    3.79    3.81    3.85
## lambda[11]    4.05    0.00 0.03    3.98    4.02    4.05    4.07    4.11
## lambda[12]    3.78    0.00 0.03    3.72    3.76    3.78    3.80    3.84
## lambda[13]    4.01    0.00 0.03    3.95    3.99    4.01    4.03    4.07
## lambda[14]    3.74    0.00 0.03    3.68    3.72    3.74    3.76    3.80
## lambda[15]    3.74    0.00 0.03    3.68    3.72    3.74    3.76    3.79
## lambda[16]    3.98    0.00 0.04    3.91    3.96    3.98    4.01    4.05
## lambda[17]    3.74    0.00 0.03    3.69    3.72    3.74    3.76    3.80
## lambda[18]    3.70    0.00 0.04    3.61    3.67    3.70    3.72    3.78
## lambda[19]    3.85    0.00 0.05    3.74    3.81    3.85    3.88    3.96
## lambda[20]    4.07    0.00 0.04    3.99    4.04    4.07    4.09    4.15
## lambda[21]    3.97    0.00 0.04    3.88    3.94    3.97    3.99    4.04
## lambda[22]    4.00    0.00 0.03    3.93    3.97    4.00    4.02    4.06
## lambda[23]    3.99    0.00 0.03    3.93    3.97    4.00    4.02    4.06
## lambda[24]    4.05    0.00 0.03    3.98    4.03    4.05    4.07    4.12
## lambda[25]    4.01    0.00 0.03    3.95    3.99    4.01    4.03    4.07
## lambda[26]    4.03    0.00 0.03    3.96    4.01    4.03    4.05    4.09
## lambda[27]    3.75    0.00 0.03    3.70    3.73    3.75    3.77    3.81
## lambda[28]    3.75    0.00 0.03    3.70    3.73    3.75    3.77    3.81
## lambda[29]    3.81    0.00 0.04    3.73    3.78    3.81    3.84    3.89
## lambda[30]    3.74    0.00 0.03    3.68    3.72    3.74    3.76    3.80
## lambda[31]    4.08    0.00 0.04    4.00    4.05    4.08    4.11    4.17
## lambda[32]    3.95    0.00 0.05    3.85    3.91    3.95    3.98    4.04
## lambda[33]    4.04    0.00 0.03    3.98    4.02    4.04    4.06    4.11
## lambda[34]    4.02    0.00 0.03    3.96    4.00    4.02    4.04    4.08
## lambda[35]    3.76    0.00 0.03    3.71    3.74    3.76    3.78    3.82
## lambda[36]    3.77    0.00 0.03    3.71    3.75    3.77    3.79    3.83
## lambda[37]    3.99    0.00 0.03    3.93    3.97    3.99    4.02    4.06
## lambda[38]    3.74    0.00 0.03    3.68    3.72    3.74    3.76    3.79
## lambda[39]    3.71    0.00 0.04    3.63    3.68    3.71    3.73    3.78
## lambda[40]    3.71    0.00 0.04    3.63    3.68    3.71    3.73    3.78
## lambda[41]    3.77    0.00 0.03    3.71    3.75    3.77    3.78    3.82
## lambda[42]    3.77    0.00 0.03    3.71    3.75    3.77    3.79    3.83
## lambda[43]    3.75    0.00 0.03    3.70    3.74    3.76    3.77    3.81
## lambda[44]    3.79    0.00 0.03    3.73    3.77    3.79    3.82    3.86
## lambda[45]    3.84    0.00 0.05    3.74    3.81    3.84    3.88    3.95
## lambda[46]    3.73    0.00 0.03    3.67    3.71    3.73    3.75    3.79
## lambda[47]    3.65    0.00 0.06    3.53    3.61    3.65    3.69    3.77
## lambda[48]    3.80    0.00 0.04    3.73    3.77    3.80    3.82    3.87
## lambda[49]    3.72    0.00 0.03    3.66    3.70    3.72    3.74    3.79
## lambda[50]    3.98    0.00 0.04    3.91    3.96    3.98    4.01    4.05
## lp__       6896.52    0.04 1.25 6893.33 6895.94 6896.86 6897.45 6897.95
##            n_eff Rhat
## b[1]        1373    1
## b[2]        1797    1
## b[3]        1422    1
## lambda[1]   1510    1
## lambda[2]   2086    1
## lambda[3]   2476    1
## lambda[4]   1902    1
## lambda[5]   1848    1
## lambda[6]   2477    1
## lambda[7]   2191    1
## lambda[8]   2044    1
## lambda[9]   2274    1
## lambda[10]  2293    1
## lambda[11]  2175    1
## lambda[12]  2431    1
## lambda[13]  2452    1
## lambda[14]  2191    1
## lambda[15]  2139    1
## lambda[16]  2111    1
## lambda[17]  2268    1
## lambda[18]  1632    1
## lambda[19]  1792    1
## lambda[20]  1882    1
## lambda[21]  1902    1
## lambda[22]  2291    1
## lambda[23]  2251    1
## lambda[24]  2130    1
## lambda[25]  2441    1
## lambda[26]  2436    1
## lambda[27]  2385    1
## lambda[28]  2385    1
## lambda[29]  2022    1
## lambda[30]  2243    1
## lambda[31]  1747    1
## lambda[32]  1775    1
## lambda[33]  2245    1
## lambda[34]  2469    1
## lambda[35]  2484    1
## lambda[36]  2490    1
## lambda[37]  2230    1
## lambda[38]  2165    1
## lambda[39]  1739    1
## lambda[40]  1722    1
## lambda[41]  2494    1
## lambda[42]  2477    1
## lambda[43]  2439    1
## lambda[44]  2236    1
## lambda[45]  1811    1
## lambda[46]  2040    1
## lambda[47]  1431    1
## lambda[48]  2199    1
## lambda[49]  1910    1
## lambda[50]  2111    1
## lp__        1212    1
## 
## Samples were drawn using NUTS(diag_e) at Fri Jul 12 00:36:58 2019.
## For each parameter, n_eff is a crude measure of effective sample size,
## and Rhat is the potential scale reduction factor on split chains (at 
## convergence, Rhat=1).

Interpreting the results

...{omitted}...
              mean se_mean   sd    2.5%     25%     50%     75%   97.5% n_eff Rhat
b[1]          3.58    0.00 0.09    3.38    3.51    3.58    3.64    3.76  1373    1
b[2]          0.26    0.00 0.04    0.18    0.24    0.26    0.29    0.35  1797    1
b[3]          0.29    0.00 0.15    0.00    0.20    0.29    0.39    0.59  1422    1
lambda[1]     3.68    0.00 0.05    3.58    3.65    3.68    3.71    3.77  1510    1
...{omitted}...
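To go beyond the printed summary, the posterior draws themselves can be pulled out of the fit object. A short sketch using rstan's extract() (object names follow the code above):

ms <- rstan::extract(fit)   # list of draws; ms$b is a 4000 x 3 matrix here
apply(ms$b, 2, quantile, probs = c(0.025, 0.5, 0.975))  # credible intervals for b[1], b[2], b[3]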

Plugging the posterior means we obtained into the mathematical expression written down earlier gives:

\[ \begin{aligned} \lambda[n] & = \exp(3.58 + 0.26A[n] + 0.29Score[n]/200) & n = 1, \dots, N \\ M[n] & \sim \text{Poisson}(\lambda[n]) & n = 1, \dots, N \end{aligned} \]

For example, for two students with Score = 150 and Score = 50 who have the same attitude toward part-time work (the same A), the ratio of their expected total class hours is:

\[ \begin{aligned} \frac{M_\text{Score = 150}}{M_\text{Score = 50}} & = \frac{\exp(3.58 + 0.26A + 0.29\times\frac{150}{200})}{\exp(3.58 + 0.26A + 0.29\times\frac{50}{200})} \\ & = \exp(0.29\times\frac{150-50}{200}) \approx 1.16 \end{aligned} \]

That is, a student whose love-of-study score is 150, compared with one whose score is only 50, selects about 16% more total class hours on average. Similarly, the ratio of expected total class hours between students who like part-time work (A = 1) and those who do not (A = 0) is \(\exp(0.26)\approx1.30\).
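Rather than plugging in posterior means alone, the same ratios can be computed draw by draw, which also gives credible intervals for these multiplicative effects. A sketch reusing the ms object extracted above:

ratio_score <- exp(ms$b[, 3] * (150 - 50) / 200)     # Score = 150 vs Score = 50
quantile(ratio_score, probs = c(0.025, 0.5, 0.975))  # posterior median around 1.16
ratio_A <- exp(ms$b[, 2])                            # A = 1 vs A = 0
quantile(ratio_A, probs = c(0.025, 0.5, 0.975))      # posterior median around 1.30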
