TuringModels

Over-dispersed Oceanic tool complexity

This is model m12.6 in Statistical Rethinking Edition 1.

  1. Data
  2. Model
  3. Output
  4. Original output

Data

import CSV
import Random

using DataFrames
using TuringModels: project_root

## Seed the global RNG before any sampling happens.
Random.seed!(1)

## Kline.csv uses ';' as its field separator.
path = joinpath(project_root, "data", "Kline.csv")
df = CSV.read(path, DataFrame; delim=';')

## Derived columns: log of population, and a running integer id per row
## that the model uses to index the per-society intercepts.
df[!, :log_pop] = log.(df[!, :population])
df[!, :society] = 1:nrow(df)
df
10×7 DataFrame
 Row │ culture     population  contact    total_tools  mean_TU  log_pop   society
     │ InlineSt…   Int64       InlineSt…  Int64        Float64  Float64   Int64
─────┼────────────────────────────────────────────────────────────────────────────
   1 │ Malekula          1100  low                 13      3.2   7.00307        1
   2 │ Tikopia           1500  low                 22      4.7   7.31322        2
   3 │ Santa Cruz        3600  low                 24      4.0   8.18869        3
   4 │ Yap               4791  high                43      5.0   8.47449        4
   5 │ Lau Fiji          7400  high                33      5.0   8.90924        5
   6 │ Trobriand         8000  high                19      4.0   8.9872         6
   7 │ Chuuk             9200  high                40      3.8   9.12696        7
   8 │ Manus            13000  low                 28      6.6   9.4727         8
   9 │ Tonga            17500  high                55      5.4   9.76996        9
  10 │ Hawaii          275000  low                 71      6.6  12.5245        10

Model

using Turing

@model function m12_6(total_tools, log_pop, society)
    ## Number of distinct societies; one intercept offset is drawn for each.
    n_societies = length(unique(society))

    ## Priors on the shared intercept and the slope on log population.
    α ~ Normal(0, 10)
    βp ~ Normal(0, 1)

    ## Cauchy(0, 1) truncated to [0, Inf) as the prior on the offset scale.
    σ_society ~ truncated(Cauchy(0, 1), 0, Inf)

    ## Per-society offsets, zero-centered with common scale σ_society.
    α_society ~ filldist(Normal(0, σ_society), n_societies)

    ## Poisson likelihood; the linear predictor lives on the log scale.
    for i in eachindex(total_tools)
        log_rate = α + α_society[society[i]] + βp * log_pop[i]
        total_tools[i] ~ Poisson(exp(log_rate))
    end
end;

Output

## Condition the model on the observed columns, then draw 1000 samples
## with NUTS (0.95 is the sampler's target acceptance rate).
tools_model = m12_6(df.total_tools, df.log_pop, df.society)
chns = sample(tools_model, NUTS(0.95), 1000)
Chains MCMC chain (1000×25×1 Array{Float64, 3}):

Iterations        = 501:1:1500
Number of chains  = 1
Samples per chain = 1000
Wall duration     = 10.21 seconds
Compute duration  = 10.21 seconds
parameters        = α, α_society[9], α_society[8], α_society[1], α_society[3], α_society[10], α_society[5], α_society[7], σ_society, α_society[6], α_society[2], βp, α_society[4]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
     parameters      mean       std   naive_se      mcse        ess      rhat   ess_per_sec
         Symbol   Float64   Float64    Float64   Float64    Float64   Float64       Float64

              α    1.1688    0.6922     0.0219    0.0422   260.6561    1.0047       25.5270
             βp    0.2540    0.0759     0.0024    0.0046   266.3444    1.0036       26.0841
      σ_society    0.3237    0.1331     0.0042    0.0077   190.3927    0.9996       18.6458
   α_society[1]   -0.2325    0.2484     0.0079    0.0121   387.8535    0.9993       37.9839
   α_society[2]    0.0323    0.2143     0.0068    0.0111   418.7738    1.0035       41.0120
   α_society[3]   -0.0520    0.2029     0.0064    0.0100   487.5709    1.0052       47.7496
   α_society[4]    0.3286    0.1953     0.0062    0.0098   277.0484    1.0051       27.1323
   α_society[5]    0.0375    0.1813     0.0057    0.0083   496.7051    1.0035       48.6441
   α_society[6]   -0.3321    0.2146     0.0068    0.0110   369.8285    0.9990       36.2186
   α_society[7]    0.1434    0.1745     0.0055    0.0074   482.4804    1.0021       47.2510
   α_society[8]   -0.1881    0.1933     0.0061    0.0093   446.6104    0.9994       43.7382
   α_society[9]    0.2771    0.1842     0.0058    0.0095   370.3378    0.9998       36.2685
  α_society[10]   -0.0743    0.2863     0.0091    0.0161   302.9755    1.0007       29.6715

Quantiles
     parameters      2.5%     25.0%     50.0%     75.0%     97.5%
         Symbol   Float64   Float64   Float64   Float64   Float64

              α   -0.2201    0.7466    1.2131    1.6150    2.5312
             βp    0.1170    0.2039    0.2502    0.3000    0.4158
      σ_society    0.1209    0.2323    0.3050    0.3884    0.6423
   α_society[1]   -0.7531   -0.3841   -0.2134   -0.0634    0.1858
   α_society[2]   -0.4063   -0.1000    0.0336    0.1585    0.4673
   α_society[3]   -0.4878   -0.1743   -0.0363    0.0761    0.3444
   α_society[4]   -0.0482    0.1938    0.3282    0.4578    0.7136
   α_society[5]   -0.3088   -0.0782    0.0350    0.1523    0.4055
   α_society[6]   -0.7949   -0.4626   -0.3137   -0.1801    0.0310
   α_society[7]   -0.1799    0.0227    0.1293    0.2537    0.4878
   α_society[8]   -0.5970   -0.3033   -0.1767   -0.0628    0.1623
   α_society[9]   -0.0802    0.1502    0.2768    0.3935    0.6670
  α_society[10]   -0.7015   -0.2424   -0.0593    0.1151    0.4698

using StatsPlots

## Visualize the sampled chains.
StatsPlots.plot(chns)

Original output

## Reference summary for m12.6, kept verbatim as text so it can be compared
## by eye with the Turing summary above (presumably the output of the R
## `rethinking` package, per the variable name — verify against the book).
m12_6rethinking = "
              Mean StdDev lower 0.89 upper 0.89 n_eff Rhat
a              1.11   0.75      -0.05       2.24  1256    1
bp             0.26   0.08       0.13       0.38  1276    1
a_society[1]  -0.20   0.24      -0.57       0.16  2389    1
a_society[2]   0.04   0.21      -0.29       0.38  2220    1
a_society[3]  -0.05   0.19      -0.36       0.25  3018    1
a_society[4]   0.32   0.18       0.01       0.60  2153    1
a_society[5]   0.04   0.18      -0.22       0.33  3196    1
a_society[6]  -0.32   0.21      -0.62       0.02  2574    1
a_society[7]   0.14   0.17      -0.13       0.40  2751    1
a_society[8]  -0.18   0.19      -0.46       0.12  2952    1
a_society[9]   0.27   0.17      -0.02       0.52  2540    1
a_society[10] -0.10   0.30      -0.52       0.37  1433    1
sigma_society  0.31   0.13       0.11       0.47  1345    1
";