
【Rでベイズ統計モデリング#15】ランダム係数モデル(brms)
記事の目的
GLMであるランダム係数モデルのベイズ推定を、RとStanを使用して実装していきます。今回は、「brms」というライブラリを使用します。データの作成から実装するので、コピペで再現することが可能です。
目次
0 前準備
0.1 今回のモデル
0.2 ワーキングディレクトリの設定
以下の画像のようにワーキングディレクトリを設定します。設定したディレクトリに、RファイルとStanファイルを保存します。
1 ライブラリ
x
10
10
1
# 1 ライブラリ
2
library(dplyr)
3
library(ggplot2)
4
library(rstan)
5
library(brms)
6
library(patchwork)
7
8
set.seed(1)
9
rstan_options(auto_write=TRUE)
10
options(mc.cores=parallel::detectCores())
2 データ
2.1 コード
1
22
22
1
# 2 データ
2
気温 <- rnorm(100,20,5) %>% round(1)
3
休日 <- rbinom(100, 1, 2/7)
4
店舗 <- runif(100, 1, 4) %>% round()
5
r1 <- rnorm(4, 0, 0.6)
6
r2 <- rnorm(4, 0, 0.4)
7
data <- data.frame(気温, 休日, 店舗) %>%
8
mutate(r1=ifelse(店舗==1, r1[1], ifelse(店舗==2, r1[2], ifelse(店舗==3, r1[3],r1[4]))))%>%
9
mutate(r2=ifelse(店舗==1, r2[1], ifelse(店舗==2, r2[2], ifelse(店舗==3, r2[3],r2[4]))))%>%
10
mutate(lambda=exp(-2+0.2*気温+(0.5+r1)*休日+r2)) %>%
11
mutate(休日=factor(休日))
12
data$売り上げ個数 <- rpois(100, data$lambda)
13
data %>% select(気温, 休日, 店舗, 売り上げ個数) %>% head()
14
15
plot <- data %>%
16
ggplot(aes(x=気温, y=売り上げ個数, color=休日)) +
17
geom_point() +
18
theme_classic(base_family = "HiraKakuPro-W3") +
19
theme(text=element_text(size=25))+
20
labs(x="気温", y="売り上げ個数", title="データ") +
21
facet_wrap(.~ 店舗)
22
plot
2.2 結果
13行目の結果 | 22行目の結果 |
![]() |
![]() |
3 brmsの利用
1
8
1
# 3 brmsの使用
2
mcmc_result <- brm(
3
data = data,
4
formula = 売り上げ個数~ 気温 + 休日 + (休日||店舗),
5
family = poisson(),
6
seed = 1,
7
iter = 2000, warmup = 200, chains = 4, thin=1
8
)
4 分析結果
4.1 コード
1
19
19
1
# 4 分析結果
2
## 4.1 推定結果
3
print(mcmc_result)
4
ranef(mcmc_result)
5
6
## 4.2 収束の確認
7
theme_set(theme_classic(base_size = 10, base_family = "HiraKakuProN-W3"))
8
plot(mcmc_result)
9
10
## 4.3 λの確認
11
condition <- data.frame(店舗=1:4)
12
plot(conditional_effects(mcmc_result, effects="気温:休日",re_formula=NULL,
13
conditions = condition), points=TRUE, ncol=2) %>%
14
wrap_plots() + plot_annotation(title="λの推定結果")
15
16
## 4.4 予測分布
17
plot(conditional_effects(mcmc_result, effects="気温:休日",re_formula=NULL,
18
conditions = condition, method="predict"), points=TRUE, ncol=2)%>%
19
wrap_plots() + plot_annotation(title="予測分布")
4.2 結果
3行目の結果 | 8行目の結果 |
![]() |
![]() |
12行目の結果 | 17行目の結果 |
![]() |
![]() |