using Distributions
using FillArrays
using StatsPlots
using LinearAlgebra
using Random
# Use the interactive PlotlyJS backend for all plots below.
plotlyjs()
Plots.PlotlyJSBackend()
# Fix the RNG seed so the synthetic data is reproducible.
Random.seed!(3)
# NOTE(review): this line is the pasted REPL echo of `Random.seed!(3)`; it
# constructs (and discards) a TaskLocalRNG and has no effect.
TaskLocalRNG()
# Mixture weights; they must sum to 1.
w = [0.5, 0.5]
# True cluster centers: each scalar μₖ is replicated across both dimensions,
# giving centers (-3.5, -3.5) and (0.5, 0.5).
μ = [-3.5, 0.5]
# Ground-truth model: two isotropic 2-D Gaussians mixed with weights w.
mixturemodel = MixtureModel([MvNormal(Fill(μₖ,2), I) for μₖ in μ], w)
MixtureModel{MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}}(K = 2)
components[1] (prior = 0.5000): MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(-3.5, 2)
Σ: [1.0 0.0; 0.0 1.0]
)
components[2] (prior = 0.5000): MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(0.5, 2)
Σ: [1.0 0.0; 0.0 1.0]
)
# Number of synthetic observations to draw from the true mixture.
N = 6000
# Sample N two-dimensional points; each column of x is one observation.
x = rand(mixturemodel, N)
# Kernel-density plots of the two coordinates of the sample.
density(x[1, :])
density!(x[2, :])
# Evaluate the mixture density at one or more points.
function mixture_pdf(x)
    return pdf(mixturemodel, x)
end
# Density values at two example points.
mixture_pdf([[0.0, 4.0], [-1, 2]])
2-element Vector{Float64}:
0.00015362065909652362
0.00838740473741466
# Pair the two coordinates of every sample into [x₁, x₂] vectors.
xs = [[a, b] for (a, b) in zip(x[1, :], x[2, :])];
# Mixture density evaluated at every sampled column of x.
xs_pdf = mixture_pdf(x);
# 3-D scatter of the samples, using the mixture density as the height.
scatter3d(
    x[1, :], x[2, :], xs_pdf;
    markersize=0.1, legend=false, color=:red,
    title="Synthetic Dataset", size=(800,400),
)
주어진 데이터 세트에서 그룹핑을 복구하고자 한다. 즉, 혼합 가중치(weight)와 모수(parameter) $\mu_1$, $\mu_2$, 그리고 생성 가우시안 혼합 모델을 위한 각 데이터의 클러스터 할당을 추론하고자 한다.
$k$개의 클러스터로 구성된 데이터 $x_i\;(i=1,...,N)$은 다음 생성 프로세스에 따라 생성된다.
디리클레분포(Dirichlet distribution)는 베타분포의 확장판이라고 할 수 있다. 베타분포는 0과 1 사이의 값을 가지는 단일(univariate) 확률변수의 베이지안 모형에 사용되고, 디리클레분포는 0과 1 사이의 값을 가지는 다변수(multivariate) 확률변수의 베이지안 모형에 사용된다.
let
    # One million draws from a symmetric Dirichlet(α = 1) on the 2-simplex;
    # each coordinate is marginally Uniform(0, 1), so both densities look flat.
    draws = rand(Dirichlet(2, 1.0), 1_000_000)
    density(draws[1, :])
    density!(draws[2, :])
end
# Draw one random 2-component weight vector (sums to one) from Dirichlet(α = 1).
w = rand(Dirichlet(2, 1.0))
display(size(w))
display(w)
# Sample five cluster assignments according to the weights w.
rand(Categorical(w), 5)
(2,)
2-element Vector{Float64}:
0.3645928706520138
0.6354071293479864
5-element Vector{Int64}:
2
2
1
1
2
# One draw (as a 2×1 matrix) from a standard 2-D normal: zero mean, identity covariance.
rand(MvNormal(Zeros(2), I),1)
2×1 Matrix{Float64}:
1.816361101257206
-1.1862323182615306
using Turing
using DataFrames
# Turing model: 2-component Gaussian mixture with unknown means, weights, and
# per-datum cluster assignments. `x` is a D×N matrix whose columns are observations.
@model function gaussian_mixture_model(x)
# Number of mixture components (fixed to two).
K = 2
# One scalar mean per cluster, drawn from a standard normal. Cluster k's center
# replicates μ[k] across every dimension: e.g. a draw μ = [-0.82, 1.14] gives
# the centers [-0.82, -0.82] and [1.14, 1.14].
μ ~ MvNormal(Zeros(K), I)
# Mixture weights from a flat Dirichlet prior (all αₖ = 1), so ∑wᵢ = 1.
w ~ Dirichlet(K, 1.0)
# Categorical distribution over cluster labels, governed by the weights.
distribution_assignments = Categorical(w)
# D: data dimension, N: number of observations.
D, N = size(x)
# Likelihood of each cluster: isotropic Gaussian centered at Fill(μₖ, D).
distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
# Draw an assignment for each datum and model the datum as coming from the
# chosen cluster's multivariate normal.
k = Vector{Int}(undef,N)
for i in 1:N
# Cluster label for observation i.
k[i] ~ distribution_assignments
# Observation i follows its assigned cluster's distribution.
x[:,i] ~ distribution_clusters[k[i]]
end
# Return the vector of sampled assignments.
return k
end
gaussian_mixture_model (generic function with 2 methods)
# Use a smaller dataset for inference so MCMC stays fast.
N = 60
# Fresh synthetic sample from the true mixture (a 2×60 matrix).
x = rand(mixturemodel, N)
# Condition the Turing model on the observed data.
model = gaussian_mixture_model(x)
DynamicPPL.Model{typeof(gaussian_mixture_model), (:x,), (), (), Tuple{Matrix{Float64}}, Tuple{}, DynamicPPL.DefaultContext}(gaussian_mixture_model, (x = [-1.995712209060883 -3.5535472667182626 … -4.4501339298974445 1.0566045473216428; -0.43691107556630615 -2.843174596703907 … -4.036749975745415 -0.65206523018811],), NamedTuple(), DynamicPPL.DefaultContext())
매개변수 𝜇 및 𝑤와 할당 𝑘 의 사후 분포 근사치를 얻기 위해 MCMC 시뮬레이션을 실행합니다. 이산 파라미터(할당 𝑘)에는 입자 깁스 샘플러를, 연속 파라미터(𝜇 및 𝑤)에는 해밀토니안 몬테카를로 샘플러를 결합한 깁스 샘플러를 사용합니다. 멀티 스레딩을 사용하여 여러 체인을 병렬로 생성합니다.
# Gibbs sampler combining particle Gibbs (100 particles) for the discrete
# assignments k with HMC (step size 0.05, 10 leapfrog steps) for μ and w.
g_sampler = Gibbs(PG(100,:k), HMC(0.05, 10, :μ, :w))
n_samples = 100
n_chains = 3
# Run 3 chains of 100 iterations in parallel across threads.
chains = sample(model, g_sampler, MCMCThreads(), n_samples, n_chains)
Sampling (3 threads): 100%|█████████████████████████████| Time: 0:00:04
Chains MCMC chain (100×65×3 Array{Float64, 3}):
Iterations = 1:1:100
Number of chains = 3
Samples per chain = 100
Wall duration = 320.95 seconds
Compute duration = 957.58 seconds
parameters = μ[1], μ[2], w[1], w[2], k[1], k[2], k[3], k[4], k[5], k[6], k[7], k[8], k[9], k[10], k[11], k[12], k[13], k[14], k[15], k[16], k[17], k[18], k[19], k[20], k[21], k[22], k[23], k[24], k[25], k[26], k[27], k[28], k[29], k[30], k[31], k[32], k[33], k[34], k[35], k[36], k[37], k[38], k[39], k[40], k[41], k[42], k[43], k[44], k[45], k[46], k[47], k[48], k[49], k[50], k[51], k[52], k[53], k[54], k[55], k[56], k[57], k[58], k[59], k[60]
internals = lp
Summary Statistics
parameters mean std mcse ess_bulk ess_tail ⋯
Symbol Float64 Float64 Float64 Float64 Float64 ⋯
μ[1] -0.8084 1.9205 1.0593 5.1302 63.4218 ⋯
μ[2] -2.1337 1.9045 1.0443 5.1449 93.6059 ⋯
w[1] 0.4946 0.0789 0.0074 236.3241 127.7179 ⋯
w[2] 0.5054 0.0789 0.0074 236.3241 127.7179 ⋯
k[1] 1.4133 0.4933 0.2142 5.3004 NaN ⋯
k[2] 1.6667 0.4722 0.2615 3.2609 NaN ⋯
k[3] 1.3600 0.4808 0.2505 3.6839 NaN ⋯
k[4] 1.6667 0.4722 0.2615 3.2609 NaN ⋯
k[5] 1.3467 0.4767 0.2578 3.4192 NaN ⋯
k[6] 1.6667 0.4722 0.2615 3.2609 NaN ⋯
k[7] 1.6667 0.4722 0.2615 3.2609 NaN ⋯
k[8] 1.3333 0.4722 0.2615 3.2609 NaN 15752611252 ⋯
k[9] 1.6667 0.4722 0.2615 3.2609 NaN ⋯
k[10] 1.6667 0.4722 0.2615 3.2609 NaN ⋯
k[11] 1.6667 0.4722 0.2615 3.2609 NaN ⋯
k[12] 1.6667 0.4722 0.2615 3.2609 NaN ⋯
k[13] 1.3433 0.4756 0.2587 3.3812 NaN ⋯
k[14] 1.6667 0.4722 0.2615 3.2609 NaN ⋯
k[15] 1.3433 0.4756 0.2553 3.4714 NaN ⋯
k[16] 1.3433 0.4756 0.2587 3.3812 NaN ⋯
k[17] 1.6667 0.4722 0.2615 3.2609 NaN ⋯
k[18] 1.6600 0.4745 0.2594 3.3450 NaN ⋯
k[19] 1.6667 0.4722 0.2615 3.2609 NaN ⋯
⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋮ ⋱
2 columns and 41 rows omitted
Quantiles
parameters 2.5% 25.0% 50.0% 75.0% 97.5%
Symbol Float64 Float64 Float64 Float64 Float64
μ[1] -3.7010 -3.4167 0.4261 0.5800 0.8828
μ[2] -3.7804 -3.5509 -3.3976 0.4003 0.7006
w[1] 0.3701 0.4606 0.4989 0.5437 0.6075
w[2] 0.3925 0.4563 0.5011 0.5394 0.6299
k[1] 1.0000 1.0000 1.0000 2.0000 2.0000
k[2] 1.0000 1.0000 2.0000 2.0000 2.0000
k[3] 1.0000 1.0000 1.0000 2.0000 2.0000
k[4] 1.0000 1.0000 2.0000 2.0000 2.0000
k[5] 1.0000 1.0000 1.0000 2.0000 2.0000
k[6] 1.0000 1.0000 2.0000 2.0000 2.0000
k[7] 1.0000 1.0000 2.0000 2.0000 2.0000
k[8] 1.0000 1.0000 1.0000 2.0000 2.0000
k[9] 1.0000 1.0000 2.0000 2.0000 2.0000
k[10] 1.0000 1.0000 2.0000 2.0000 2.0000
k[11] 1.0000 1.0000 2.0000 2.0000 2.0000
k[12] 1.0000 1.0000 2.0000 2.0000 2.0000
k[13] 1.0000 1.0000 1.0000 2.0000 2.0000
k[14] 1.0000 1.0000 2.0000 2.0000 2.0000
k[15] 1.0000 1.0000 1.0000 2.0000 2.0000
k[16] 1.0000 1.0000 1.0000 2.0000 2.0000
k[17] 1.0000 1.0000 2.0000 2.0000 2.0000
k[18] 1.0000 1.0000 2.0000 2.0000 2.0000
k[19] 1.0000 1.0000 2.0000 2.0000 2.0000
⋮ ⋮ ⋮ ⋮ ⋮ ⋮
41 rows omitted
# Flatten the MCMC chains into a DataFrame: one row per (iteration, chain).
df = DataFrame(chains)
| Row | iteration | chain | μ[1] | μ[2] | w[1] | w[2] | k[1] | k[2] | k[3] | k[4] | k[5] | k[6] | k[7] | k[8] | k[9] | k[10] | k[11] | k[12] | k[13] | k[14] | k[15] | k[16] | k[17] | k[18] | k[19] | k[20] | k[21] | k[22] | k[23] | k[24] | k[25] | k[26] | k[27] | k[28] | k[29] | k[30] | k[31] | k[32] | k[33] | k[34] | k[35] | k[36] | k[37] | k[38] | k[39] | k[40] | k[41] | k[42] | k[43] | k[44] | k[45] | k[46] | k[47] | k[48] | k[49] | k[50] | k[51] | k[52] | k[53] | k[54] | k[55] | k[56] | k[57] | k[58] | k[59] | k[60] | lp |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| Int64 | Int64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | |
| 1 | 1 | 1 | 0.806586 | -1.46989 | 0.16277 | 0.83723 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | -399.401 |
| 2 | 2 | 1 | 1.00127 | -3.0938 | 0.437076 | 0.562924 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | -296.311 |
| 3 | 3 | 1 | 1.00127 | -3.0938 | 0.437076 | 0.562924 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -252.515 |
| 4 | 4 | 1 | 0.460392 | -3.59038 | 0.418698 | 0.581302 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -254.984 |
| 5 | 5 | 1 | 0.890698 | -2.82879 | 0.469583 | 0.530417 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -261.595 |
| 6 | 6 | 1 | 0.352181 | -3.78578 | 0.534744 | 0.465256 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -238.056 |
| 7 | 7 | 1 | 0.591203 | -3.29433 | 0.426217 | 0.573783 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -240.432 |
| 8 | 8 | 1 | 0.741131 | -3.54489 | 0.443792 | 0.556208 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -232.511 |
| 9 | 9 | 1 | 0.651008 | -3.44931 | 0.477932 | 0.522068 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -231.401 |
| 10 | 10 | 1 | 0.359191 | -3.49657 | 0.520304 | 0.479696 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -233.431 |
| 11 | 11 | 1 | 0.764042 | -3.32044 | 0.375132 | 0.624868 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -233.765 |
| 12 | 12 | 1 | 0.483242 | -3.42788 | 0.481913 | 0.518087 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -231.863 |
| 13 | 13 | 1 | 0.542634 | -3.42531 | 0.446496 | 0.553504 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -232.734 |
| ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
| 289 | 89 | 3 | 0.431869 | -3.46359 | 0.51884 | 0.48116 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -230.338 |
| 290 | 90 | 3 | 0.382087 | -3.5937 | 0.374334 | 0.625666 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -232.983 |
| 291 | 91 | 3 | 0.382087 | -3.5937 | 0.374334 | 0.625666 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -232.983 |
| 292 | 92 | 3 | 0.51947 | -3.60965 | 0.564478 | 0.435522 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -230.41 |
| 293 | 93 | 3 | 0.46369 | -3.47764 | 0.515613 | 0.484387 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -230.184 |
| 294 | 94 | 3 | 0.362756 | -3.6061 | 0.544 | 0.456 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -230.689 |
| 295 | 95 | 3 | 0.631187 | -3.54805 | 0.472372 | 0.527628 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -230.857 |
| 296 | 96 | 3 | 0.483107 | -3.4399 | 0.567408 | 0.432592 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -230.423 |
| 297 | 97 | 3 | 0.490669 | -3.4831 | 0.402005 | 0.597995 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -230.969 |
| 298 | 98 | 3 | 0.596414 | -3.30452 | 0.440985 | 0.559015 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -231.115 |
| 299 | 99 | 3 | 0.769064 | -3.7695 | 0.511288 | 0.488712 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -233.438 |
| 300 | 100 | 3 | 0.175728 | -3.38995 | 0.46976 | 0.53024 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | -234.459 |
let
    # Sanity check: every chain's posterior mean of μ should recover the
    # data-generating means. μ[1] and μ[2] can swap (label switching), so the
    # estimates are sorted before comparing.
    for chain_id in MCMCChains.chains(chains)
        draws = Array(chains[:, ["μ[1]", "μ[2]"], chain_id])
        μ_mean = vec(mean(draws; dims=1))
        println(sort(μ_mean), μ)
        @assert isapprox(sort(μ_mean), μ; rtol=0.1) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
    end
end
[-3.4843836622735846, 0.5394186230220864][-3.5, 0.5] [-3.5072251084217667, 0.5232151482050279][-3.5, 0.5] [-3.4400456293350423, 0.5426133471604079][-3.5, 0.5]
let
    # Extract every draw of μ[1] and μ[2] from the first chain as a 100×2 matrix.
    μ_draws = chains[:, ["μ[1]", "μ[2]"], 1]
    Array(μ_draws)
end
100×2 Matrix{Float64}:
0.806586 -1.46989
1.00127 -3.0938
1.00127 -3.0938
0.460392 -3.59038
0.890698 -2.82879
0.352181 -3.78578
0.591203 -3.29433
0.741131 -3.54489
0.651008 -3.44931
0.359191 -3.49657
0.764042 -3.32044
0.483242 -3.42788
0.542634 -3.42531
⋮
0.599547 -3.24166
0.361801 -3.70378
0.720165 -3.33989
0.283199 -3.72268
0.758261 -3.33467
0.225492 -3.59689
0.623473 -3.57775
0.273148 -3.61686
0.758842 -3.54824
0.355432 -3.47965
0.635855 -3.5731
0.383734 -3.55743
let
    # Per-chain posterior means of μ computed from the DataFrame: rows come in
    # three consecutive blocks of 100 iterations, one block per chain.
    map(1:100:300) do start
        block = df[start:start+99, ["μ[1]", "μ[2]"]]
        [mean(col) for col in eachcol(block)]
    end
end
3-element Vector{Vector{Float64}}:
[0.5394186230220864, -3.4843836622735846]
[-3.5072251084217667, 0.5232151482050279]
[0.5426133471604079, -3.4400456293350423]
# Trace/density plots of the cluster means, one color per parameter.
plot(chains[["μ[1]", "μ[2]"]]; colordim=:parameter, legend=false, titlefontsize=9)
# Same plots for the mixture weights.
plot(chains[["w[1]", "w[2]"]]; colordim=:parameter, legend=false, titlefontsize=9)
# Keep only the first chain for the point-estimate visualization.
chain = chains[:, :, 1]
# Plug-in mixture model built from the posterior means of μ and w.
μ_mean = [mean(chain, "μ[$i]") for i in 1:2]
w_mean = [mean(chain, "w[$i]") for i in 1:2]
mixturemodel_mean = MixtureModel([MvNormal(Fill(μₖ, 2), I) for μₖ in μ_mean], w_mean)
# Log-density contours of the estimated mixture with the data overlaid.
contour(
    range(-7.5, 3; length=1_000),
    range(-6.5, 3; length=1_000),
    (x, y) -> logpdf(mixturemodel_mean, [x, y]);
    widen=false,
)
# Typo fixed: "Sythetic" -> "Synthetic".
scatter!(x[1, :], x[2, :]; legend=false, title="Synthetic Dataset")