TuringModels

Centered oceanic tool complexity

This is model m10.10stan.c from the first edition of Statistical Rethinking. It shows how centering a predictor can shorten sampling time when the model's parameters are highly correlated in the posterior.

  1. Data
  2. Model
  3. Original output

Data

import CSV
import Random
import TuringModels

using DataFrames
using Statistics: mean

Random.seed!(1)

data_path = joinpath(TuringModels.project_root, "data", "Kline.csv")
df = CSV.read(data_path, DataFrame; delim=';')

df.log_pop = log.(df.population)
df.contact_high = [contact == "high" ? 1 : 0 for contact in df.contact]
10-element Vector{Int64}:
 0
 0
 0
 1
 1
 1
 1
 0
 1
 0

Add a new column, log_pop_c, that centers the log_pop values by subtracting their mean:

mean_log_pop = mean(df.log_pop)
df.log_pop_c = df.log_pop .- mean_log_pop
df
10×8 DataFrame
 Row │ culture     population  contact  total_tools  mean_TU  log_pop   contact_high  log_pop_c
     │ String15    Int64       String7  Int64        Float64  Float64   Int64         Float64
─────┼───────────────────────────────────────────────────────────────────────────────────────────
   1 │ Malekula          1100  low               13      3.2   7.00307             0  -1.97394
   2 │ Tikopia           1500  low               22      4.7   7.31322             0  -1.66378
   3 │ Santa Cruz        3600  low               24      4.0   8.18869             0  -0.788316
   4 │ Yap               4791  high              43      5.0   8.47449             1  -0.50251
   5 │ Lau Fiji          7400  high              33      5.0   8.90924             1  -0.0677695
   6 │ Trobriand         8000  high              19      4.0   8.9872              1   0.0101921
   7 │ Chuuk             9200  high              40      3.8   9.12696             1   0.149954
   8 │ Manus            13000  low               28      6.6   9.4727              0   0.4957
   9 │ Tonga            17500  high              55      5.4   9.76996             1   0.792951
  10 │ Hawaii          275000  low               71      6.6  12.5245              0   3.54752
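Why centering helps can be illustrated with a rough least-squares analogy. This sketch is not the Poisson model fitted on this page; it only assumes that the correlation between an intercept and a slope behaves similarly in both settings. The helper `intercept_slope_cor` is a hypothetical name introduced here, and the `log_pop` values are copied from the table above.

```julia
using LinearAlgebra
using Statistics: mean

# log_pop values copied from the DataFrame shown above
log_pop = [7.00307, 7.31322, 8.18869, 8.47449, 8.90924,
           8.9872, 9.12696, 9.4727, 9.76996, 12.5245]

# Hypothetical helper: correlation between the intercept and slope estimates
# implied by inv(X'X), which is the least-squares covariance up to a constant.
function intercept_slope_cor(x)
    X = hcat(ones(length(x)), x)   # design matrix: intercept column + predictor
    V = inv(X' * X)
    return V[1, 2] / sqrt(V[1, 1] * V[2, 2])
end

cor_raw = intercept_slope_cor(log_pop)
cor_centered = intercept_slope_cor(log_pop .- mean(log_pop))

println("uncentered: ", round(cor_raw; digits=3))
println("centered:   ", round(cor_centered; digits=3))
```

Under these assumptions the uncentered predictor gives an intercept/slope correlation close to -1, while the centered predictor gives a correlation of essentially zero. A sampler has a much easier time exploring the second geometry.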

Model

using Turing

@model function m10_10stan_c(total_tools, log_pop_c, contact_high)
    α ~ Normal(0, 100)
    βp ~ Normal(0, 1)
    βc ~ Normal(0, 1)
    βpc ~ Normal(0, 1)

    for i ∈ 1:length(total_tools)
        λ = exp(α + βp*log_pop_c[i] + βc*contact_high[i] +
            βpc*contact_high[i]*log_pop_c[i])
        total_tools[i] ~ Poisson(λ)
    end
end;
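Written out as equations, the model above appears to correspond to the following Poisson regression with a log link. The symbols $T_i$ (total_tools), $x_i$ (log_pop_c), and $C_i$ (contact_high) are notation chosen here for illustration, not taken from the source; the second argument of each Normal is a standard deviation, as in Distributions.jl.

```latex
\begin{aligned}
T_i &\sim \mathrm{Poisson}(\lambda_i) \\
\log \lambda_i &= \alpha + \beta_p x_i + \beta_c C_i + \beta_{pc}\, C_i x_i \\
x_i &= \log(\mathrm{population}_i) - \overline{\log(\mathrm{population})} \\
\alpha &\sim \mathrm{Normal}(0, 100) \\
\beta_p,\ \beta_c,\ \beta_{pc} &\sim \mathrm{Normal}(0, 1)
\end{aligned}
```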

chns = sample(
    m10_10stan_c(df.total_tools, df.log_pop_c, df.contact_high),
    NUTS(),
    1000
)
Chains MCMC chain (1000×16×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 3.41 seconds
Compute duration  = 3.41 seconds
parameters        = α, βp, βc, βpc
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse        ess      rhat   ess_per_sec
      Symbol   Float64   Float64    Float64   Float64    Float64   Float64       Float64

           α    3.3080    0.0901     0.0028    0.0037   714.0783    1.0069      209.5300
          βp    0.2623    0.0363     0.0011    0.0014   773.7110    0.9995      227.0279
          βc    0.2876    0.1196     0.0038    0.0048   723.8301    1.0053      212.3915
         βpc    0.0627    0.1746     0.0055    0.0054   803.7383    0.9990      235.8387

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

           α    3.1246    3.2481    3.3111    3.3717    3.4747
          βp    0.1896    0.2388    0.2627    0.2865    0.3348
          βc    0.0683    0.2035    0.2849    0.3672    0.5283
         βpc   -0.2845   -0.0591    0.0608    0.1812    0.4103
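To read the coefficients on the outcome scale, one option is to plug the posterior means from the Summary Statistics table into the model's formula for λ. This is a minimal sketch: `expected_tools` is a hypothetical helper, and plugging in posterior means only approximates the posterior mean of λ, which would require averaging over all draws.

```julia
# Posterior means copied from the Summary Statistics table above
α, βp, βc, βpc = 3.3080, 0.2623, 0.2876, 0.0627

# Hypothetical helper: plug-in expected tool count for a given centered
# log population and contact indicator (0 = low, 1 = high).
expected_tools(log_pop_c, contact_high) =
    exp(α + βp * log_pop_c + βc * contact_high + βpc * contact_high * log_pop_c)

# A society with average log population (log_pop_c = 0)
λ_low = expected_tools(0.0, 0)
λ_high = expected_tools(0.0, 1)

println("low contact:  ", round(λ_low; digits=1))
println("high contact: ", round(λ_high; digits=1))
```

Because the predictor is centered, α is directly interpretable: exp(α) is the plug-in expected tool count for a low-contact society with average log population, about 27 tools here, against about 36 for a high-contact society of the same size.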

using StatsPlots

StatsPlots.plot(chns)
[Figure chns.svg: trace and density plots for α, βp, βc, and βpc]

Original output

m_10_10t_c_result = "
    mean   sd  5.5% 94.5% n_eff Rhat
 a   3.31 0.09  3.17  3.45  3671    1
 bp  0.26 0.03  0.21  0.32  5052    1
 bc  0.28 0.12  0.10  0.47  3383    1
 bcp 0.07 0.17 -0.20  0.34  4683    1
";