This is model m10.10stan.c from the first edition of Statistical Rethinking.
The model shows how centering predictors can reduce sampling time when parameters are highly correlated.
import CSV
import Random
import TuringModels
using DataFrames
using Statistics: mean
# Seed the global RNG so that the run is reproducible.
Random.seed!(1)

# Read the semicolon-delimited Kline data set from the TuringModels project.
data_path = joinpath(TuringModels.project_root, "data", "Kline.csv")
df = CSV.read(data_path, DataFrame; delim=';')

# Natural logarithm of each population value.
df.log_pop = map(log, df.population)

# Integer indicator: 1 where `contact` equals "high", 0 otherwise.
df.contact_high = Int.(df.contact .== "high")
10-element Vector{Int64}:
0
0
0
1
1
1
1
0
1
0
Add a new column in which the log_pop values are centered (!), i.e. the mean of log_pop is subtracted from every value:
# Center `log_pop` by subtracting its mean from every entry, and store the
# result as the new column `log_pop_c`.
mean_log_pop = mean(df.log_pop)
df.log_pop_c = df.log_pop .- mean_log_pop

# Show the data frame including the newly added columns.
df
10×8 DataFrame
Row │ culture population contact total_tools mean_TU log_pop contact_high log_pop_c
│ String15 Int64 String7 Int64 Float64 Float64 Int64 Float64
─────┼───────────────────────────────────────────────────────────────────────────────────────────
1 │ Malekula 1100 low 13 3.2 7.00307 0 -1.97394
2 │ Tikopia 1500 low 22 4.7 7.31322 0 -1.66378
3 │ Santa Cruz 3600 low 24 4.0 8.18869 0 -0.788316
4 │ Yap 4791 high 43 5.0 8.47449 1 -0.50251
5 │ Lau Fiji 7400 high 33 5.0 8.90924 1 -0.0677695
6 │ Trobriand 8000 high 19 4.0 8.9872 1 0.0101921
7 │ Chuuk 9200 high 40 3.8 9.12696 1 0.149954
8 │ Manus 13000 low 28 6.6 9.4727 0 0.4957
9 │ Tonga 17500 high 55 5.4 9.76996 1 0.792951
10 │ Hawaii 275000 low 71 6.6 12.5245 0 3.54752
using Turing
# Poisson regression with a log link and an interaction term
# (model m10.10stan.c).
#
# Arguments are parallel collections with one entry per observation:
# - `total_tools`:  observed counts, each modeled as a Poisson draw.
# - `log_pop_c`:    predictor scaled by `βp` (and by `βpc` in the interaction).
# - `contact_high`: predictor scaled by `βc` (and by `βpc` in the interaction).
@model function m10_10stan_c(total_tools, log_pop_c, contact_high)
    # Priors: a wide Normal on the intercept, unit-scale Normals on the slopes.
    α ~ Normal(0, 100)
    βp ~ Normal(0, 1)
    βc ~ Normal(0, 1)
    βpc ~ Normal(0, 1)

    # Iterate over the shared indices of all three inputs; `eachindex` with
    # several arrays raises a `DimensionMismatch` if their shapes disagree,
    # instead of silently trusting the length of `total_tools` alone.
    for i in eachindex(total_tools, log_pop_c, contact_high)
        # The linear predictor is exponentiated to obtain the Poisson rate.
        λ = exp(α + βp*log_pop_c[i] + βc*contact_high[i] +
                βpc*contact_high[i]*log_pop_c[i])
        total_tools[i] ~ Poisson(λ)
    end
end;
# Draw 1000 posterior samples from the centered model using the NUTS sampler
# constructed with its default arguments.
chns = let model = m10_10stan_c(df.total_tools, df.log_pop_c, df.contact_high)
    sample(model, NUTS(), 1000)
end
Chains MCMC chain (1000×16×1 Array{Float64, 3}):
Iterations = 501:1:1500
Number of chains = 1
Samples per chain = 1000
Wall duration = 3.41 seconds
Compute duration = 3.41 seconds
parameters = α, βp, βc, βpc
internals = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size
Summary Statistics
parameters mean std naive_se mcse ess rhat ess_per_sec
Symbol Float64 Float64 Float64 Float64 Float64 Float64 Float64
α 3.3080 0.0901 0.0028 0.0037 714.0783 1.0069 209.5300
βp 0.2623 0.0363 0.0011 0.0014 773.7110 0.9995 227.0279
βc 0.2876 0.1196 0.0038 0.0048 723.8301 1.0053 212.3915
βpc 0.0627 0.1746 0.0055 0.0054 803.7383 0.9990 235.8387
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
α 3.1246 3.2481 3.3111 3.3717 3.4747
βp 0.1896 0.2388 0.2627 0.2865 0.3348
βc 0.0683 0.2035 0.2849 0.3672 0.5283
βpc -0.2845 -0.0591 0.0608 0.1812 0.4103
using StatsPlots
# Plot the sampled chains stored in `chns`.
StatsPlots.plot(chns)
"/home/runner/work/TuringModels.jl/TuringModels.jl/__site/assets/models/centered-oceanic-tool-complexity/code/output/chns.svg"
# Summary table kept as a plain string; it is not parsed or used by the code
# shown here.
# NOTE(review): presumably the reference estimates for this model from the
# book's Stan fit, kept for comparison with `chns` — confirm. The rows
# a/bp/bc/bcp appear to correspond to α/βp/βc/βpc.
m_10_10t_c_result = "
mean sd 5.5% 94.5% n_eff Rhat
a 3.31 0.09 3.17 3.45 3671 1
bp 0.26 0.03 0.21 0.32 5052 1
bc 0.28 0.12 0.10 0.47 3383 1
bcp 0.07 0.17 -0.20 0.34 4683 1
";