m10.04d1

Load Julia packages (libraries) needed for the snippets in chapter 0

using DynamicHMCModels, ForwardDiff, Flux, ReverseDiff
gr(size=(400,400))

CmdStan uses a tmp directory to store the output of cmdstan

ProjDir = rel_path_d("..", "scripts", "10")
cd(ProjDir)

snippet 10.4

d = CSV.read(rel_path("..", "data", "chimpanzees.csv"), delim=';');
df = convert(DataFrame, d);
df[!, :pulled_left] = convert(Array{Int64}, df[!, :pulled_left])
df[!, :prosoc_left] = convert(Array{Int64}, df[!, :prosoc_left])
df[!, :condition] = convert(Array{Int64}, df[!, :condition])
df[!, :actor] = convert(Array{Int64}, df[!, :actor])
first(df[!, [:actor, :pulled_left, :prosoc_left, :condition]], 5)

struct m_10_04d_model{TY <: AbstractVector, TX <: AbstractMatrix,
  TA <: AbstractVector}
    "Observations."
    y::TY
    "Covariates"
    X::TX
    "Actors"
    A::TA
    "Number of observations"
    N::Int
    "Number of unique actors"
    N_actors::Int
end

Make the type callable with the parameters as a single argument.

function (problem::m_10_04d_model)(θ)
    @unpack y, X, A, N, N_actors = problem   # extract the data
    @unpack β, α = θ  # works on the named tuple too
    ll = 0.0
    ll += sum(logpdf.(Normal(0, 10), β)) # bp & bpC
    ll += sum(logpdf.(Normal(0, 10), α)) # alpha[1:7]
    ll += sum(
      [loglikelihood(Binomial(1, logistic(α[A[i]] + dot(X[i, :], β))), [y[i]]) for i in 1:N]
    )
    ll
end

Instantiate the model with data and inits.

N = size(df, 1)
N_actors = length(unique(df[!, :actor]))
X = hcat(ones(Int64, N), df[!, :prosoc_left] .* df[!, :condition]);
A = df[!, :actor]
y = df[!, :pulled_left]
p = m_10_04d_model(y, X, A, N, N_actors);
θ = (β = [1.0, 0.0], α = [-1.0, 10.0, -1.0, -1.0, -1.0, 0.0, 2.0])
p(θ)
-305.21943396408915

Write a function to return properly dimensioned transformation.

problem_transformation(p::m_10_04d_model) =
    as( (β = as(Array, size(p.X, 2)), α = as(Array, p.N_actors), ) )
problem_transformation (generic function with 1 method)

Wrap the problem with a transformation, then use Flux for the gradient.

P = TransformedLogDensity(problem_transformation(p), p)
TransformedLogDensity of dimension 9

For stress testing

do_stresstest = false

#ad = :Flux
ad = :ForwardDiff
#ad = :ReverseDiff

if do_stresstest
  ∇P = ADgradient(:ForwardDiff, P);
  #st = LogDensityProblems.stresstest(p, N=1000, scale=1.0)
  #display(st)
else
  ∇P = LogDensityRejectErrors(ADgradient(ad, P));
end
LogDensityRejectErrors{InvalidLogDensityException,LogDensityProblems.ForwardDiffLogDensity{TransformedLogDensity{TransformVariables.TransformTuple{NamedTuple{(:β, :α),Tuple{TransformVariables.ArrayTransform{TransformVariables.Identity,1},TransformVariables.ArrayTransform{TransformVariables.Identity,1}}}},Main.ex-m10.04d1.m_10_04d_model{Array{Int64,1},Array{Int64,2},Array{Int64,1}}},ForwardDiff.GradientConfig{ForwardDiff.Tag{getfield(LogDensityProblems, Symbol("##1#2")){TransformedLogDensity{TransformVariables.TransformTuple{NamedTuple{(:β, :α),Tuple{TransformVariables.ArrayTransform{TransformVariables.Identity,1},TransformVariables.ArrayTransform{TransformVariables.Identity,1}}}},Main.ex-m10.04d1.m_10_04d_model{Array{Int64,1},Array{Int64,2},Array{Int64,1}}}},Float64},Float64,9,Array{ForwardDiff.Dual{ForwardDiff.Tag{getfield(LogDensityProblems, Symbol("##1#2")){TransformedLogDensity{TransformVariables.TransformTuple{NamedTuple{(:β, :α),Tuple{TransformVariables.ArrayTransform{TransformVariables.Identity,1},TransformVariables.ArrayTransform{TransformVariables.Identity,1}}}},Main.ex-m10.04d1.m_10_04d_model{Array{Int64,1},Array{Int64,2},Array{Int64,1}}}},Float64},Float64,9},1}}}}(ForwardDiff AD wrapper for TransformedLogDensity of dimension 9, w/ chunk size 9)

Run single chains

chain, NUTS_tuned = NUTS_init_tune_mcmc(∇P, 3000);
posterior = TransformVariables.transform.(Ref(problem_transformation(p)),
  get_position.(chain));
3000-element Array{NamedTuple{(:β, :α),Tuple{Array{Float64,1},Array{Float64,1}}},1}:
 (β = [6.603461470200594, 0.8731462370599579], α = [-7.108030781372997, 1.8307779401384146, -7.276260143922696, -7.4718435131187055, -7.2384292211391275, -6.1649915625305525, -4.125960780217898])      
 (β = [-0.7948172660830137, 0.6707858315419812], α = [0.3931919250653251, 5.832839090557499, 0.44319935382797637, -0.030395668273310417, 0.7115302260133931, 1.5930798225672216, 2.4353518149936937])    
 (β = [2.687698907629545, 0.5487036657923703], α = [-3.17527342404502, 5.809970680688019, -3.2135822899545845, -3.560438193856329, -3.146533037299935, -2.3467288630892584, -0.22285287169527868])       
 (β = [0.28964576292792305, 0.26200256781720066], α = [-0.6037962831937097, 6.753081807117683, -1.268151478332617, -0.9932418245163587, -0.6165359399547453, 0.3001032330036401, 1.802083888996083])     
 (β = [1.0663916155625817, 0.27052340186628043], α = [-1.149657005976225, 4.318711960837669, -2.289855770193022, -1.843989783157901, -1.169469340983003, -0.7710184904198361, 1.0500704414502764])       
 (β = [0.0633040746966457, 0.8293932657771456], α = [-0.4662473315656014, 11.76588311972391, -0.015144358228897425, -1.1453825361766168, -0.45624772123423485, -0.09958820926557149, 1.8028343374448013])
 (β = [-2.9317798530718613, -0.19910991890304539], α = [2.941031684815498, 9.231653403080692, 2.309144017406354, 2.6236501576337368, 2.418204297671485, 4.2735412695644595, 5.316080519039785])          
 (β = [-0.00563546901036438, 0.3524461102457531], α = [-0.3646786377100779, 5.231746799736418, -0.8872941829006613, -1.2667450524908894, -0.7085674651343723, 0.6291095032749097, 2.1755854465258038])   
 (β = [3.5305064448373233, 0.48816493389705795], α = [-4.157440107778783, 11.522096709354116, -4.247009127558339, -4.0669296486037005, -3.7962478626608958, -3.2522713365104843, -1.4268135181980763])   
 (β = [4.553034118360961, 0.29303919379535226], α = [-4.778227544619665, 1.8899159249011892, -5.308375516687748, -5.432403655646659, -5.1791963076248635, -3.8935302407977925, -2.98315469485819])       
 ⋮                                                                                                                                                                                                       
 (β = [5.689145213467929, 0.24877653431359764], α = [-6.30538170349988, -0.4201652397012483, -6.277204093516436, -6.052252626670219, -5.995663014264362, -5.872470706984473, -3.5971942096960547])       
 (β = [4.069983286726469, 0.6997548844232694], α = [-4.518476869077594, 9.952455234933097, -5.111009884746467, -5.2040309747577265, -4.700546166942033, -3.3729383844087977, -2.5900368094760626])       
 (β = [2.3193347423092403, 0.3540741661083513], α = [-2.389765977572349, 8.690020509223096, -3.345951299261018, -2.8042877347425725, -2.846234996858588, -1.5577392031449118, -0.40230156596675504])     
 (β = [1.9321838440305492, -0.02001143019217619], α = [-2.21659683539953, 9.71724329262789, -2.7882687185120325, -2.5263614424374072, -2.2363227738699116, -1.2052371766075038, 0.3777353231370431])     
 (β = [-2.1161935330414097, 0.7959093851807156], α = [1.4820105966243653, 9.294439951779358, 1.4061825972819755, 1.2746585261810242, 1.5361036414642935, 2.282388449345532, 3.537656624017777])          
 (β = [3.669314430450167, -0.06306644610609036], α = [-4.147635303204296, 3.3369933956759406, -4.4127330272205265, -4.422511205651178, -4.124515563752176, -2.885341882698689, -0.9158591100640103])     
 (β = [2.280433991815419, 0.9117305557116371], α = [-2.650083224042836, 12.603317772892204, -2.9032977596106004, -2.942186298789138, -2.600623558143786, -2.010136913287744, -0.6311534406684174])       
 (β = [2.003245353258863, 0.926217040399798], α = [-2.413574158644768, 8.34411037143406, -2.9231743507563928, -3.067115709977377, -2.9188150506668737, -1.728206841094448, -0.06603179632919637])        
 (β = [3.326379886991486, -0.14258946327630564], α = [-4.433593144273074, 5.7658390843735585, -3.8533162910382037, -3.7678734903798468, -3.6086328208071556, -2.7183218261231548, -0.6420130314349701])  

Result rethinking

rethinking = "
      mean   sd  5.5% 94.5% n_eff Rhat
a[1] -0.74 0.27 -1.19 -0.31  2899    1
a[2] 10.77 5.20  4.60 20.45  1916    1
a[3] -1.05 0.28 -1.50 -0.62  3146    1
a[4] -1.05 0.28 -1.50 -0.61  3525    1
a[5] -0.73 0.28 -1.17 -0.28  3637    1
a[6]  0.22 0.27 -0.21  0.67  3496    1
a[7]  1.82 0.41  1.21  2.50  3202    1
bp    0.83 0.27  0.42  1.27  2070    1
bpC  -0.13 0.31 -0.62  0.34  3430    1
";
"\n      mean   sd  5.5% 94.5% n_eff Rhat\na[1] -0.74 0.27 -1.19 -0.31  2899    1\na[2] 10.77 5.20  4.60 20.45  1916    1\na[3] -1.05 0.28 -1.50 -0.62  3146    1\na[4] -1.05 0.28 -1.50 -0.61  3525    1\na[5] -0.73 0.28 -1.17 -0.28  3637    1\na[6]  0.22 0.27 -0.21  0.67  3496    1\na[7]  1.82 0.41  1.21  2.50  3202    1\nbp    0.83 0.27  0.42  1.27  2070    1\nbpC  -0.13 0.31 -0.62  0.34  3430    1\n"

Set varable names, this will be automated using θ

parameter_names = ["bp", "bpC"]
pooled_parameter_names = ["a[$i]" for i in 1:7];
7-element Array{String,1}:
 "a[1]"
 "a[2]"
 "a[3]"
 "a[4]"
 "a[5]"
 "a[6]"
 "a[7]"

Create a3d

a3d = Array{Float64, 3}(undef, 3000, 9, 1);
for i in 1:3000
  a3d[i, 1:2, 1] = values(posterior[i][1])
  a3d[i, 3:9, 1] = values(posterior[i][2])
end

chns = MCMCChains.Chains(a3d,
  vcat(parameter_names, pooled_parameter_names),
  Dict(
    :parameters => parameter_names,
    :pooled => pooled_parameter_names
  )
);
Object of type Chains, with data of type 3000×9×1 Array{Float64,3}

Iterations        = 1:3000
Thinning interval = 1
Chains            = 1
Samples per chain = 3000
pooled            = a[1], a[2], a[3], a[4], a[5], a[6], a[7]
parameters        = bp, bpC

2-element Array{ChainDataFrame,1}

Summary Statistics
. Omitted printing of 1 columns
│ Row │ parameters │ mean     │ std      │ naive_se   │ mcse       │ ess     │
│     │ Symbol     │ Float64  │ Float64  │ Float64    │ Float64    │ Any     │
├─────┼────────────┼──────────┼──────────┼────────────┼────────────┼─────────┤
│ 1   │ bp         │ 1.52165  │ 3.63396  │ 0.0663466  │ 0.0957052  │ 1134.66 │
│ 2   │ bpC        │ 0.412055 │ 0.245361 │ 0.00447967 │ 0.00449436 │ 3000.0  │

Quantiles

│ Row │ parameters │ 2.5%       │ 25.0%     │ 50.0%    │ 75.0%    │ 97.5%    │
│     │ Symbol     │ Float64    │ Float64   │ Float64  │ Float64  │ Float64  │
├─────┼────────────┼────────────┼───────────┼──────────┼──────────┼──────────┤
│ 1   │ bp         │ -5.75568   │ -0.894522 │ 1.52198  │ 3.91638  │ 8.75534  │
│ 2   │ bpC        │ -0.0756923 │ 0.245969  │ 0.407616 │ 0.575444 │ 0.898349 │

Describe the chain

describe(chns)
2-element Array{ChainDataFrame,1}

Summary Statistics
. Omitted printing of 1 columns
│ Row │ parameters │ mean     │ std      │ naive_se   │ mcse       │ ess     │
│     │ Symbol     │ Float64  │ Float64  │ Float64    │ Float64    │ Any     │
├─────┼────────────┼──────────┼──────────┼────────────┼────────────┼─────────┤
│ 1   │ bp         │ 1.52165  │ 3.63396  │ 0.0663466  │ 0.0957052  │ 1134.66 │
│ 2   │ bpC        │ 0.412055 │ 0.245361 │ 0.00447967 │ 0.00449436 │ 3000.0  │

Quantiles

│ Row │ parameters │ 2.5%       │ 25.0%     │ 50.0%    │ 75.0%    │ 97.5%    │
│     │ Symbol     │ Float64    │ Float64   │ Float64  │ Float64  │ Float64  │
├─────┼────────────┼────────────┼───────────┼──────────┼──────────┼──────────┤
│ 1   │ bp         │ -5.75568   │ -0.894522 │ 1.52198  │ 3.91638  │ 8.75534  │
│ 2   │ bpC        │ -0.0756923 │ 0.245969  │ 0.407616 │ 0.575444 │ 0.898349 │

Describe the chain

describe(chns, sections=[:pooled])
2-element Array{ChainDataFrame,1}

Summary Statistics
. Omitted printing of 1 columns
│ Row │ parameters │ mean     │ std     │ naive_se  │ mcse      │ ess     │
│     │ Symbol     │ Float64  │ Float64 │ Float64   │ Float64   │ Any     │
├─────┼────────────┼──────────┼─────────┼───────────┼───────────┼─────────┤
│ 1   │ a[1]       │ -1.97781 │ 3.64504 │ 0.066549  │ 0.0961998 │ 1147.36 │
│ 2   │ a[2]       │ 9.85265  │ 5.92061 │ 0.108095  │ 0.174025  │ 972.633 │
│ 3   │ a[3]       │ -2.27552 │ 3.64414 │ 0.0665326 │ 0.0954322 │ 1147.57 │
│ 4   │ a[4]       │ -2.27588 │ 3.63597 │ 0.0663833 │ 0.0954781 │ 1137.75 │
│ 5   │ a[5]       │ -1.97114 │ 3.64342 │ 0.0665194 │ 0.0955994 │ 1134.28 │
│ 6   │ a[6]       │ -1.03457 │ 3.64655 │ 0.0665765 │ 0.0952521 │ 1139.12 │
│ 7   │ a[7]       │ 0.53125  │ 3.6466  │ 0.0665775 │ 0.0972008 │ 1123.3  │

Quantiles

│ Row │ parameters │ 2.5%      │ 25.0%    │ 50.0%    │ 75.0%    │ 97.5%   │
│     │ Symbol     │ Float64   │ Float64  │ Float64  │ Float64  │ Float64 │
├─────┼────────────┼───────────┼──────────┼──────────┼──────────┼─────────┤
│ 1   │ a[1]       │ -9.22579  │ -4.31753 │ -2.01708 │ 0.4579   │ 5.28433 │
│ 2   │ a[2]       │ -0.251699 │ 5.69313  │ 9.31642  │ 13.447   │ 23.5678 │
│ 3   │ a[3]       │ -9.5852   │ -4.63254 │ -2.27662 │ 0.160106 │ 4.97501 │
│ 4   │ a[4]       │ -9.48227  │ -4.6551  │ -2.31992 │ 0.163695 │ 5.06924 │
│ 5   │ a[5]       │ -9.2475   │ -4.31611 │ -1.98012 │ 0.503501 │ 5.3256  │
│ 6   │ a[6]       │ -8.25119  │ -3.40936 │ -1.04996 │ 1.37006  │ 6.37548 │
│ 7   │ a[7]       │ -6.7622   │ -1.83372 │ 0.499678 │ 2.92218  │ 7.83578 │

Plot the chain parameters

plot(chns)
0 1000 2000 3000 -10 -5 0 5 10 bp Iteration Sample value -15 -10 -5 0 5 10 15 0.000 0.025 0.050 0.075 0.100 bp Sample value Density 0 1000 2000 3000 -0.5 0.0 0.5 1.0 bpC Iteration Sample value -0.5 0.0 0.5 1.0 1.5 0.0 0.5 1.0 1.5 bpC Sample value Density

End of m10.04d1.jl

This page was generated using Literate.jl.