Load Julia packages (libraries) needed for the snippets in chapter 10
using DynamicHMCModels, ForwardDiff, Flux, ReverseDiff
gr(size=(400,400))  # NOTE(review): presumably selects the Plots GR backend with a 400×400 default figure size
Make the chapter 10 scripts directory the working directory (this DynamicHMC version does not run CmdStan, so no CmdStan tmp output directory is involved)
# Directory of the chapter 10 scripts (rel_path_d is a project helper; presumably
# it resolves the path relative to the package root — TODO confirm).
ProjDir = rel_path_d("..", "scripts", "10")
# Switch the working directory to it for the rest of the script.
cd(ProjDir)
snippet 10.4
# Read the semicolon-delimited chimpanzee data and wrap it in a DataFrame.
d = CSV.read(rel_path("..", "data", "chimpanzees.csv"), delim=';');
df = convert(DataFrame, d);
# Store every column the model uses as a plain Int64 vector (same order as before).
for colname in (:pulled_left, :prosoc_left, :condition, :actor)
    df[!, colname] = convert(Array{Int64}, df[!, colname])
end
# Peek at the first five rows of the modelled columns.
first(df[!, [:actor, :pulled_left, :prosoc_left, :condition]], 5)
"""
    m_10_04d_model(y, X, A, N, N_actors)

Data container for the chimpanzee pulled-left model. An instance is made callable
elsewhere in this script; called with a parameter named tuple `(β = ..., α = ...)`
it returns the log density.
"""
struct m_10_04d_model{
        TY <: AbstractVector,
        TX <: AbstractMatrix,
        TA <: AbstractVector
    }
    "Observed outcomes; `y[i]` belongs to row `i` of `X`."
    y::TY
    "Covariate matrix, one row per observation."
    X::TX
    "Actor index of each observation (selects that actor's intercept)."
    A::TA
    "Total number of observations."
    N::Int
    "Count of distinct actors."
    N_actors::Int
end
Make the type callable with the parameters as a single argument.
# Log density (up to an additive constant) of the chimpanzee model.
# θ must provide fields
#   β : coefficients multiplying the columns of X (bp, bpC)
#   α : one intercept per actor, selected by A[i]
# Every element of β and α gets a Normal(0, 10) prior, and each y[i] is one
# Binomial(1, logistic(α[A[i]] + X[i, :] ⋅ β)) trial.
function (problem::m_10_04d_model)(θ)
    @unpack y, X, A, N = problem   # extract the data
    @unpack β, α = θ               # works on the named tuple too
    prior = Normal(0, 10)
    # Seed the accumulator with the prior terms so `ll` already has the element
    # type of θ (e.g. ForwardDiff.Dual) instead of switching from Float64.
    ll = sum(logpdf.(prior, β)) + sum(logpdf.(prior, α))   # bp & bpC, then alpha[1:N_actors]
    for i in 1:N
        # Scalar logpdf of a single trial; avoids allocating a one-element
        # vector per observation as loglikelihood(d, [y[i]]) did.
        ll += logpdf(Binomial(1, logistic(α[A[i]] + dot(X[i, :], β))), y[i])
    end
    return ll
end
Instantiate the model with data and inits.
N = size(df, 1)
N_actors = length(unique(df[!, :actor]))
# Design matrix for logit(p) = α[actor] + (bp + bpC * condition) * prosoc_left:
# column 1 multiplies bp, column 2 multiplies bpC. There is deliberately no
# column of ones, because the per-actor α already supplies the intercept; an
# extra intercept column would make bp non-identifiable with α.
# NOTE: output captured in this page was produced with the old `ones(N)` first
# column; regenerate it to see results comparable to rethinking.
X = hcat(df[!, :prosoc_left], df[!, :prosoc_left] .* df[!, :condition]);
A = df[!, :actor]
y = df[!, :pulled_left]
p = m_10_04d_model(y, X, A, N, N_actors);
θ = (β = [1.0, 0.0], α = [-1.0, 10.0, -1.0, -1.0, -1.0, 0.0, 2.0])
p(θ)
-305.21943396408915
Write a function that returns a properly dimensioned transformation.
# Transformation from a flat parameter vector of length size(p.X, 2) + p.N_actors
# to the named tuple (β, α) the model expects: β holds one coefficient per column
# of X, α one intercept per actor.
function problem_transformation(p::m_10_04d_model)
    n_coef = size(p.X, 2)
    return as((β = as(Array, n_coef), α = as(Array, p.N_actors)))
end
problem_transformation (generic function with 1 method)
Wrap the problem with a transformation, then use automatic differentiation (the backend chosen by `ad` below, ForwardDiff by default) for the gradient.
# Pair the (β, α) transformation with the model so the log density can be evaluated
# on a flat parameter vector (dimension 9 here: 2 coefficients + 7 actor intercepts).
P = TransformedLogDensity(problem_transformation(p), p)
TransformedLogDensity of dimension 9
Choose the AD backend; set `do_stresstest = true` to build an unwrapped gradient for stress testing instead.
do_stresstest = false
# AD backend for the log-density gradient; other options: :Flux, :ReverseDiff.
ad = :ForwardDiff
if do_stresstest
    # Use the selected backend here too (previously hard-coded to :ForwardDiff),
    # without the LogDensityRejectErrors wrapper used in the other branch.
    ∇P = ADgradient(ad, P);
    # NOTE(review): stresstest should be given the gradient object rather than the
    # raw model `p`; confirm the signature for the installed LogDensityProblems.
    #st = LogDensityProblems.stresstest(∇P, N=1000, scale=1.0)
    #display(st)
else
    ∇P = LogDensityRejectErrors(ADgradient(ad, P));
end
LogDensityRejectErrors{InvalidLogDensityException,LogDensityProblems.ForwardDiffLogDensity{TransformedLogDensity{TransformVariables.TransformTuple{NamedTuple{(:β, :α),Tuple{TransformVariables.ArrayTransform{TransformVariables.Identity,1},TransformVariables.ArrayTransform{TransformVariables.Identity,1}}}},Main.ex-m10.04d1.m_10_04d_model{Array{Int64,1},Array{Int64,2},Array{Int64,1}}},ForwardDiff.GradientConfig{ForwardDiff.Tag{getfield(LogDensityProblems, Symbol("##1#2")){TransformedLogDensity{TransformVariables.TransformTuple{NamedTuple{(:β, :α),Tuple{TransformVariables.ArrayTransform{TransformVariables.Identity,1},TransformVariables.ArrayTransform{TransformVariables.Identity,1}}}},Main.ex-m10.04d1.m_10_04d_model{Array{Int64,1},Array{Int64,2},Array{Int64,1}}}},Float64},Float64,9,Array{ForwardDiff.Dual{ForwardDiff.Tag{getfield(LogDensityProblems, Symbol("##1#2")){TransformedLogDensity{TransformVariables.TransformTuple{NamedTuple{(:β, :α),Tuple{TransformVariables.ArrayTransform{TransformVariables.Identity,1},TransformVariables.ArrayTransform{TransformVariables.Identity,1}}}},Main.ex-m10.04d1.m_10_04d_model{Array{Int64,1},Array{Int64,2},Array{Int64,1}}}},Float64},Float64,9},1}}}}(ForwardDiff AD wrapper for TransformedLogDensity of dimension 9, w/ chunk size 9)
Run a single chain of 3000 draws
chain, NUTS_tuned = NUTS_init_tune_mcmc(∇P, 3000);
# Map every draw's position vector back to the (β, α) named tuple.
posterior = let trans = problem_transformation(p)
    [TransformVariables.transform(trans, get_position(draw)) for draw in chain]
end;
3000-element Array{NamedTuple{(:β, :α),Tuple{Array{Float64,1},Array{Float64,1}}},1}:
(β = [6.603461470200594, 0.8731462370599579], α = [-7.108030781372997, 1.8307779401384146, -7.276260143922696, -7.4718435131187055, -7.2384292211391275, -6.1649915625305525, -4.125960780217898])
(β = [-0.7948172660830137, 0.6707858315419812], α = [0.3931919250653251, 5.832839090557499, 0.44319935382797637, -0.030395668273310417, 0.7115302260133931, 1.5930798225672216, 2.4353518149936937])
(β = [2.687698907629545, 0.5487036657923703], α = [-3.17527342404502, 5.809970680688019, -3.2135822899545845, -3.560438193856329, -3.146533037299935, -2.3467288630892584, -0.22285287169527868])
(β = [0.28964576292792305, 0.26200256781720066], α = [-0.6037962831937097, 6.753081807117683, -1.268151478332617, -0.9932418245163587, -0.6165359399547453, 0.3001032330036401, 1.802083888996083])
(β = [1.0663916155625817, 0.27052340186628043], α = [-1.149657005976225, 4.318711960837669, -2.289855770193022, -1.843989783157901, -1.169469340983003, -0.7710184904198361, 1.0500704414502764])
(β = [0.0633040746966457, 0.8293932657771456], α = [-0.4662473315656014, 11.76588311972391, -0.015144358228897425, -1.1453825361766168, -0.45624772123423485, -0.09958820926557149, 1.8028343374448013])
(β = [-2.9317798530718613, -0.19910991890304539], α = [2.941031684815498, 9.231653403080692, 2.309144017406354, 2.6236501576337368, 2.418204297671485, 4.2735412695644595, 5.316080519039785])
(β = [-0.00563546901036438, 0.3524461102457531], α = [-0.3646786377100779, 5.231746799736418, -0.8872941829006613, -1.2667450524908894, -0.7085674651343723, 0.6291095032749097, 2.1755854465258038])
(β = [3.5305064448373233, 0.48816493389705795], α = [-4.157440107778783, 11.522096709354116, -4.247009127558339, -4.0669296486037005, -3.7962478626608958, -3.2522713365104843, -1.4268135181980763])
(β = [4.553034118360961, 0.29303919379535226], α = [-4.778227544619665, 1.8899159249011892, -5.308375516687748, -5.432403655646659, -5.1791963076248635, -3.8935302407977925, -2.98315469485819])
⋮
(β = [5.689145213467929, 0.24877653431359764], α = [-6.30538170349988, -0.4201652397012483, -6.277204093516436, -6.052252626670219, -5.995663014264362, -5.872470706984473, -3.5971942096960547])
(β = [4.069983286726469, 0.6997548844232694], α = [-4.518476869077594, 9.952455234933097, -5.111009884746467, -5.2040309747577265, -4.700546166942033, -3.3729383844087977, -2.5900368094760626])
(β = [2.3193347423092403, 0.3540741661083513], α = [-2.389765977572349, 8.690020509223096, -3.345951299261018, -2.8042877347425725, -2.846234996858588, -1.5577392031449118, -0.40230156596675504])
(β = [1.9321838440305492, -0.02001143019217619], α = [-2.21659683539953, 9.71724329262789, -2.7882687185120325, -2.5263614424374072, -2.2363227738699116, -1.2052371766075038, 0.3777353231370431])
(β = [-2.1161935330414097, 0.7959093851807156], α = [1.4820105966243653, 9.294439951779358, 1.4061825972819755, 1.2746585261810242, 1.5361036414642935, 2.282388449345532, 3.537656624017777])
(β = [3.669314430450167, -0.06306644610609036], α = [-4.147635303204296, 3.3369933956759406, -4.4127330272205265, -4.422511205651178, -4.124515563752176, -2.885341882698689, -0.9158591100640103])
(β = [2.280433991815419, 0.9117305557116371], α = [-2.650083224042836, 12.603317772892204, -2.9032977596106004, -2.942186298789138, -2.600623558143786, -2.010136913287744, -0.6311534406684174])
(β = [2.003245353258863, 0.926217040399798], α = [-2.413574158644768, 8.34411037143406, -2.9231743507563928, -3.067115709977377, -2.9188150506668737, -1.728206841094448, -0.06603179632919637])
(β = [3.326379886991486, -0.14258946327630564], α = [-4.433593144273074, 5.7658390843735585, -3.8533162910382037, -3.7678734903798468, -3.6086328208071556, -2.7183218261231548, -0.6420130314349701])
Results from rethinking (for comparison)
# Posterior summary reported by the R rethinking package, kept verbatim as a
# string for comparison with the DynamicHMC results below.
rethinking = "
                mean   sd  5.5% 94.5% n_eff Rhat
a[1]        -0.74 0.27 -1.19 -0.31  2899    1
a[2]        10.77 5.20  4.60 20.45  1916    1
a[3]        -1.05 0.28 -1.50 -0.62  3146    1
a[4]        -1.05 0.28 -1.50 -0.61  3525    1
a[5]        -0.73 0.28 -1.17 -0.28  3637    1
a[6]         0.22 0.27 -0.21  0.67  3496    1
a[7]         1.82 0.41  1.21  2.50  3202    1
bp           0.83 0.27  0.42  1.27  2070    1
bpC         -0.13 0.31 -0.62  0.34  3430    1
";
"\n mean sd 5.5% 94.5% n_eff Rhat\na[1] -0.74 0.27 -1.19 -0.31 2899 1\na[2] 10.77 5.20 4.60 20.45 1916 1\na[3] -1.05 0.28 -1.50 -0.62 3146 1\na[4] -1.05 0.28 -1.50 -0.61 3525 1\na[5] -0.73 0.28 -1.17 -0.28 3637 1\na[6] 0.22 0.27 -0.21 0.67 3496 1\na[7] 1.82 0.41 1.21 2.50 3202 1\nbp 0.83 0.27 0.42 1.27 2070 1\nbpC -0.13 0.31 -0.62 0.34 3430 1\n"
Set variable names; this will be automated using θ
# Names of the β coefficients, in the order they appear in θ.β.
parameter_names = ["bp", "bpC"]
# One name per actor intercept, sized from the data (7 actors in chimpanzees.csv)
# instead of a hard-coded 7.
pooled_parameter_names = ["a[$i]" for i in 1:N_actors];
7-element Array{String,1}:
"a[1]"
"a[2]"
"a[3]"
"a[4]"
"a[5]"
"a[6]"
"a[7]"
Create a3d, the draws × parameters × chains array used to build the MCMCChains object
# Pack the posterior draws into a draws × parameters × chains array, with the β
# coefficients first and the actor intercepts α after them. The dimensions are
# derived from the draws and name lists rather than hard-coded (3000, 9).
a3d = let n_draws = length(posterior),
          n_β = length(parameter_names),
          n_params = n_β + length(pooled_parameter_names)
    arr = Array{Float64, 3}(undef, n_draws, n_params, 1)
    for (i, draw) in enumerate(posterior)
        arr[i, 1:n_β, 1] = draw.β
        arr[i, (n_β + 1):n_params, 1] = draw.α
    end
    arr
end;
chns = MCMCChains.Chains(a3d,
  vcat(parameter_names, pooled_parameter_names),
  Dict(
    :parameters => parameter_names,
    :pooled => pooled_parameter_names
  )
);
Object of type Chains, with data of type 3000×9×1 Array{Float64,3}
Iterations = 1:3000
Thinning interval = 1
Chains = 1
Samples per chain = 3000
pooled = a[1], a[2], a[3], a[4], a[5], a[6], a[7]
parameters = bp, bpC
2-element Array{ChainDataFrame,1}
Summary Statistics
. Omitted printing of 1 columns
│ Row │ parameters │ mean │ std │ naive_se │ mcse │ ess │
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Any │
├─────┼────────────┼──────────┼──────────┼────────────┼────────────┼─────────┤
│ 1 │ bp │ 1.52165 │ 3.63396 │ 0.0663466 │ 0.0957052 │ 1134.66 │
│ 2 │ bpC │ 0.412055 │ 0.245361 │ 0.00447967 │ 0.00449436 │ 3000.0 │
Quantiles
│ Row │ parameters │ 2.5% │ 25.0% │ 50.0% │ 75.0% │ 97.5% │
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Float64 │
├─────┼────────────┼────────────┼───────────┼──────────┼──────────┼──────────┤
│ 1 │ bp │ -5.75568 │ -0.894522 │ 1.52198 │ 3.91638 │ 8.75534 │
│ 2 │ bpC │ -0.0756923 │ 0.245969 │ 0.407616 │ 0.575444 │ 0.898349 │
Describe the chain
# Summary statistics and quantiles for the default :parameters section (bp, bpC).
describe(chns)
2-element Array{ChainDataFrame,1}
Summary Statistics
. Omitted printing of 1 columns
│ Row │ parameters │ mean │ std │ naive_se │ mcse │ ess │
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Any │
├─────┼────────────┼──────────┼──────────┼────────────┼────────────┼─────────┤
│ 1 │ bp │ 1.52165 │ 3.63396 │ 0.0663466 │ 0.0957052 │ 1134.66 │
│ 2 │ bpC │ 0.412055 │ 0.245361 │ 0.00447967 │ 0.00449436 │ 3000.0 │
Quantiles
│ Row │ parameters │ 2.5% │ 25.0% │ 50.0% │ 75.0% │ 97.5% │
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Float64 │
├─────┼────────────┼────────────┼───────────┼──────────┼──────────┼──────────┤
│ 1 │ bp │ -5.75568 │ -0.894522 │ 1.52198 │ 3.91638 │ 8.75534 │
│ 2 │ bpC │ -0.0756923 │ 0.245969 │ 0.407616 │ 0.575444 │ 0.898349 │
Describe the pooled (per-actor intercept) parameters of the chain
# Same summaries, restricted to the :pooled section (a[1] … a[7]).
describe(chns, sections=[:pooled])
2-element Array{ChainDataFrame,1}
Summary Statistics
. Omitted printing of 1 columns
│ Row │ parameters │ mean │ std │ naive_se │ mcse │ ess │
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Any │
├─────┼────────────┼──────────┼─────────┼───────────┼───────────┼─────────┤
│ 1 │ a[1] │ -1.97781 │ 3.64504 │ 0.066549 │ 0.0961998 │ 1147.36 │
│ 2 │ a[2] │ 9.85265 │ 5.92061 │ 0.108095 │ 0.174025 │ 972.633 │
│ 3 │ a[3] │ -2.27552 │ 3.64414 │ 0.0665326 │ 0.0954322 │ 1147.57 │
│ 4 │ a[4] │ -2.27588 │ 3.63597 │ 0.0663833 │ 0.0954781 │ 1137.75 │
│ 5 │ a[5] │ -1.97114 │ 3.64342 │ 0.0665194 │ 0.0955994 │ 1134.28 │
│ 6 │ a[6] │ -1.03457 │ 3.64655 │ 0.0665765 │ 0.0952521 │ 1139.12 │
│ 7 │ a[7] │ 0.53125 │ 3.6466 │ 0.0665775 │ 0.0972008 │ 1123.3 │
Quantiles
│ Row │ parameters │ 2.5% │ 25.0% │ 50.0% │ 75.0% │ 97.5% │
│ │ Symbol │ Float64 │ Float64 │ Float64 │ Float64 │ Float64 │
├─────┼────────────┼───────────┼──────────┼──────────┼──────────┼─────────┤
│ 1 │ a[1] │ -9.22579 │ -4.31753 │ -2.01708 │ 0.4579 │ 5.28433 │
│ 2 │ a[2] │ -0.251699 │ 5.69313 │ 9.31642 │ 13.447 │ 23.5678 │
│ 3 │ a[3] │ -9.5852 │ -4.63254 │ -2.27662 │ 0.160106 │ 4.97501 │
│ 4 │ a[4] │ -9.48227 │ -4.6551 │ -2.31992 │ 0.163695 │ 5.06924 │
│ 5 │ a[5] │ -9.2475 │ -4.31611 │ -1.98012 │ 0.503501 │ 5.3256 │
│ 6 │ a[6] │ -8.25119 │ -3.40936 │ -1.04996 │ 1.37006 │ 6.37548 │
│ 7 │ a[7] │ -6.7622 │ -1.83372 │ 0.499678 │ 2.92218 │ 7.83578 │
Plot the chain parameters
# Plot the sampled chain (NOTE(review): which sections are shown depends on the MCMCChains plot recipe).
plot(chns)
End of m10.04d1.jl
This page was generated using Literate.jl.