Tutorial – Gaussian Mixture Model
In [1]:
using Distributions
using FillArrays
using StatsPlots
using LinearAlgebra
using Random
#plotlyjs()
In [2]:
Random.seed!(3)
Out[2]:
TaskLocalRNG()
In [3]:
# The weights w must sum to 1.
w = [0.5, 0.5]
μ = [-3.5, 0.5]
mixturemodel = MixtureModel([MvNormal(Fill(μₖ,2), I) for μₖ in μ], w)
Out[3]:
MixtureModel{MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}}(K = 2)
components[1] (prior = 0.5000):
MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(-3.5, 2)
Σ: [1.0 0.0; 0.0 1.0]
)
components[2] (prior = 0.5000):
MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(0.5, 2)
Σ: [1.0 0.0; 0.0 1.0]
)
In [4]:
N = 6000
x = rand(mixturemodel, N)
density(x[1,:])
density!(x[2,:])
Out[4]:
[plot: overlaid density curves of the two coordinates of x]
In [5]:
mixture_pdf(x) = pdf(mixturemodel,x)
mixture_pdf([[0.0,4.0],[-1,2]])
Out[5]:
2-element Vector{Float64}:
 0.00015362065909652362
 0.00838740473741466
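By definition, the mixture density is the weighted sum of the component densities. A quick sanity check of the value above, sketched with Distributions' `probs` and `components` accessors:
In [ ]:
# Check that pdf(mixturemodel, p) equals the weighted sum of component pdfs
# (a sketch; compare against the first Out[5] value).
let p = [0.0, 4.0]
    manual = sum(wₖ * pdf(comp, p)
                 for (wₖ, comp) in zip(probs(mixturemodel), components(mixturemodel)))
    manual ≈ pdf(mixturemodel, p)   # true
end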
In [6]:
# Collect the coordinates of each sample into a 2-element vector.
xs = [collect(t) for t in zip(x[1,:], x[2,:])];
In [7]:
# Evaluate the mixture density at every data point (column-wise).
xs_pdf = mixture_pdf(x);
In [8]:
scatter3d(x[1,:], x[2,:], xs_pdf; markersize=0.1, legend=false,
          title="Synthetic Dataset", color=:red, size=(800, 400))
Out[8]:
[plot: 3-D scatter of the samples against their mixture density, "Synthetic Dataset"]
Gaussian Mixture Model in Turing
We want to recover the grouping from the given dataset, i.e., to infer the mixture weights, the parameters $\mu_1$ and $\mu_2$, and the cluster assignment of each datum in the generative Gaussian mixture model.
The data $x_i\;(i=1,\dots,N)$, comprising $K$ clusters, are generated according to the following process (sketched in code right after this list):
- Draw the model parameters: the cluster centers $\mu_k \sim \mathcal{N}(0,1)\;\;(k=1,\dots,K)$ and the cluster weights $w \sim Dirichlet(\alpha_1,\dots,\alpha_K)$.
- Once the model parameters are set, generate the $N$ observations by first picking a cluster and then drawing the observation near that cluster's center:
  - $z_i \sim Categorical(w),\;\;\;(i=1,\dots,N)$
  - $x_i \sim \mathcal{N}([\mu_{z_i},\mu_{z_i}]^T,I),\;\;\;(i=1,\dots,N)$
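To make the process concrete, here is a hand-rolled sketch of it in plain Distributions code (the helper name `generate_gmm_data` is hypothetical; the notebook itself uses `MixtureModel` instead):
In [ ]:
# A hand-rolled sketch of the generative process above.
function generate_gmm_data(N, K; α=1.0)
    μ = rand(Normal(0, 1), K)      # μₖ ~ N(0, 1): one center value per cluster
    w = rand(Dirichlet(K, α))      # w ~ Dirichlet(α, …, α): mixture weights
    z = rand(Categorical(w), N)    # zᵢ ~ Categorical(w): cluster assignments
    # xᵢ ~ N([μ_zᵢ, μ_zᵢ]ᵀ, I): one 2-D observation per assignment
    x = reduce(hcat, [rand(MvNormal(Fill(μ[zᵢ], 2), I)) for zᵢ in z])
    return x, z, μ, w
end

x_demo, z_demo, μ_demo, w_demo = generate_gmm_data(60, 2);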
The Dirichlet distribution can be viewed as a multivariate extension of the Beta distribution. The Beta distribution is used in Bayesian models of a single (univariate) random variable taking values between 0 and 1, while the Dirichlet distribution is used for a multivariate random variable whose components all lie between 0 and 1.
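With $K = 2$ the connection is easy to verify empirically: each coordinate of a Dirichlet(α, α) draw is marginally Beta(α, α). A minimal sketch, assuming only the packages loaded in In [1]:
In [ ]:
# For K = 2, the first Dirichlet coordinate is marginally Beta(α, α).
let α = 1.0
    d = rand(Dirichlet(2, α), 100_000)   # columns are 2-element probability vectors
    b = rand(Beta(α, α), 100_000)
    mean(d[1, :]), mean(b)               # both ≈ 0.5 for α = 1
end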
In [9]:
let
    x = rand(Dirichlet(2, 1.0), 1_000_000)
    density(x[1, :])
    density!(x[2, :])
end
Out[9]:
[plot: overlaid density curves of the two Dirichlet coordinates]
In [25]:
w = rand(Dirichlet(2,1.0))
size(w) |> display
w |> display
rand(Categorical(w),5)
(2,)
2-element Vector{Float64}:
 0.11466848154062723
 0.8853315184593729
Out[25]:
5-element Vector{Int64}:
 2
 2
 2
 2
 2
In [11]:
rand(MvNormal(Zeros(2), I),1)
Out[11]:
2×1 Matrix{Float64}:
 0.4953993220164408
 1.5884707065784054
In [12]:
using Turing
using DataFrames
In [13]:
@model function gaussian_mixture_model(x)
    # Draw the parameters for each of the K=2 clusters from a standard normal distribution.
    K = 2
    # One value is drawn per axis, e.g. [-0.82, 1.14]; each drawn value
    # becomes the center of a cluster on both axes:
    # [-0.82, -0.82] and [1.14, 1.14].
    μ ~ MvNormal(Zeros(K), I)

    # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.
    # The weights determine how likely each cluster is to be picked; ∑wᵢ = 1.
    w ~ Dirichlet(K, 1.0)
    # Construct the categorical distribution of assignments:
    # clusters are selected according to the weights.
    distribution_assignments = Categorical(w)

    # D: dimension, N: number of data points.
    D, N = size(x)
    # Set up the distribution of each cluster from which the data are drawn.
    distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]

    # Draw an assignment for each datum and generate the datum from the
    # multivariate normal distribution of the selected cluster.
    k = Vector{Int}(undef, N)
    for i in 1:N
        # Pick a cluster.
        k[i] ~ distribution_assignments
        # x follows the distribution of the selected cluster.
        x[:, i] ~ distribution_clusters[k[i]]
    end

    return k
end
Out[13]:
gaussian_mixture_model (generic function with 2 methods)
In [14]:
N = 60
x = rand(mixturemodel, N)
model = gaussian_mixture_model(x)
Out[14]:
DynamicPPL.Model{typeof(gaussian_mixture_model), (:x,), (), (), Tuple{Matrix{Float64}}, Tuple{}, DynamicPPL.DefaultContext}(gaussian_mixture_model, (x = [0.7815230657909114 0.606475608825794 … -2.7539368689652712 -3.332216606756868; 0.13160246698385814 3.2673330211511193 … -3.9413132284473344 -5.17386349552132],), NamedTuple(), DynamicPPL.DefaultContext())
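Before running inference, one optional sanity check is to sample from the prior alone and confirm that μ looks standard normal and w looks like a flat Dirichlet. A sketch, assuming Turing's `Prior` sampler:
In [ ]:
# Draw from the prior only (no data likelihood) and summarize μ and w.
prior_chain = sample(model, Prior(), 500)
mean(prior_chain[["μ[1]", "μ[2]", "w[1]", "w[2]"]])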
To obtain approximations of the posterior distributions of the parameters μ and w and of the assignments k, we run an MCMC simulation. We use a Gibbs sampler that combines a particle Gibbs (PG) sampler for the discrete parameters (the assignments k) with a Hamiltonian Monte Carlo (HMC) sampler for the continuous parameters (μ and w), and we generate multiple chains in parallel using multithreading.
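If Julia was not started with multiple threads, the same sampler can be run sequentially for a single chain (a sketch; the sampler settings mirror the cell below):
In [ ]:
# Sequential alternative to MCMCThreads: one chain, same Gibbs sampler.
single_chain = sample(model, Gibbs(PG(100, :k), HMC(0.05, 10, :μ, :w)), 100)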
In [15]:
g_sampler = Gibbs(PG(100,:k), HMC(0.05, 10, :μ, :w))
n_samples = 100
n_chains = 3
chains = sample(model, g_sampler, MCMCThreads(), n_samples, n_chains)
Out[15]:
Chains MCMC chain (100×65×3 Array{Float64, 3}):

Iterations        = 1:1:100
Number of chains  = 3
Samples per chain = 100
Wall duration     = 351.31 seconds
Compute duration  = 1049.8 seconds
parameters        = μ[1], μ[2], w[1], w[2], k[1], k[2], k[3], k[4], k[5], k[6], k[7], k[8], k[9], k[10], k[11], k[12], k[13], k[14], k[15], k[16], k[17], k[18], k[19], k[20], k[21], k[22], k[23], k[24], k[25], k[26], k[27], k[28], k[29], k[30], k[31], k[32], k[33], k[34], k[35], k[36], k[37], k[38], k[39], k[40], k[41], k[42], k[43], k[44], k[45], k[46], k[47], k[48], k[49], k[50], k[51], k[52], k[53], k[54], k[55], k[56], k[57], k[58], k[59], k[60]
internals         = lp

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail      rhat   e ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64   Float64     ⋯

        μ[1]   -3.4739    0.1930    0.0133   329.3007   128.1291    1.0081     ⋯
        μ[2]    0.5223    0.1642    0.0066   743.1364   179.2010    1.0083     ⋯
        w[1]    0.5149    0.0676    0.0034   430.8876   254.7067    1.0052     ⋯
        w[2]    0.4851    0.0676    0.0034   430.8876   254.7067    1.0052     ⋯
        k[1]    1.9833    0.1282    0.0207    38.2172        NaN    1.0435     ⋯
        k[2]    2.0000    0.0000       NaN        NaN        NaN       NaN     ⋯
        k[3]    1.0000    0.0000       NaN        NaN        NaN       NaN     ⋯
        k[4]    1.9833    0.1282    0.0207    38.2172        NaN    1.0435     ⋯
        k[5]    1.0000    0.0000       NaN        NaN        NaN       NaN     ⋯
        k[6]    1.9900    0.0997    0.0057   305.1682        NaN    0.9960     ⋯
        k[7]    1.8367    0.3703    0.0457    65.6779        NaN    1.0037     ⋯
        k[8]    1.0033    0.0577    0.0033   300.2402   300.2402    1.0000     ⋯
        k[9]    1.9933    0.0815    0.0047   302.6742        NaN    0.9980     ⋯
       k[10]    1.0000    0.0000       NaN        NaN        NaN       NaN     ⋯
       k[11]    2.0000    0.0000       NaN        NaN        NaN       NaN     ⋯
       k[12]    1.0000    0.0000       NaN        NaN        NaN       NaN     ⋯
       k[13]    1.0000    0.0000       NaN        NaN        NaN       NaN     ⋯
       k[14]    2.0000    0.0000       NaN        NaN        NaN       NaN     ⋯
       k[15]    1.0000    0.0000       NaN        NaN        NaN       NaN     ⋯
       k[16]    1.0000    0.0000       NaN        NaN        NaN       NaN     ⋯
       k[17]    1.0033    0.0577    0.0033   300.2402   300.2402    1.0000     ⋯
       k[18]    1.0000    0.0000       NaN        NaN        NaN       NaN     ⋯
       k[19]    2.0000    0.0000       NaN        NaN        NaN       NaN     ⋯
           ⋮         ⋮         ⋮         ⋮          ⋮          ⋮         ⋮     ⋱
                                                   1 column and 41 rows omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5%
      Symbol   Float64   Float64   Float64   Float64   Float64

        μ[1]   -3.7462   -3.5626   -3.4741   -3.4011   -3.1858
        μ[2]    0.2385    0.4347    0.5179    0.6115    0.8446
        w[1]    0.3809    0.4696    0.5120    0.5604    0.6474
        w[2]    0.3526    0.4396    0.4880    0.5304    0.6191
        k[1]    2.0000    2.0000    2.0000    2.0000    2.0000
        k[2]    2.0000    2.0000    2.0000    2.0000    2.0000
        k[3]    1.0000    1.0000    1.0000    1.0000    1.0000
        k[4]    2.0000    2.0000    2.0000    2.0000    2.0000
        k[5]    1.0000    1.0000    1.0000    1.0000    1.0000
        k[6]    2.0000    2.0000    2.0000    2.0000    2.0000
        k[7]    1.0000    2.0000    2.0000    2.0000    2.0000
        k[8]    1.0000    1.0000    1.0000    1.0000    1.0000
        k[9]    2.0000    2.0000    2.0000    2.0000    2.0000
       k[10]    1.0000    1.0000    1.0000    1.0000    1.0000
       k[11]    2.0000    2.0000    2.0000    2.0000    2.0000
       k[12]    1.0000    1.0000    1.0000    1.0000    1.0000
       k[13]    1.0000    1.0000    1.0000    1.0000    1.0000
       k[14]    2.0000    2.0000    2.0000    2.0000    2.0000
       k[15]    1.0000    1.0000    1.0000    1.0000    1.0000
       k[16]    1.0000    1.0000    1.0000    1.0000    1.0000
       k[17]    1.0000    1.0000    1.0000    1.0000    1.0000
       k[18]    1.0000    1.0000    1.0000    1.0000    1.0000
       k[19]    2.0000    2.0000    2.0000    2.0000    2.0000
           ⋮         ⋮         ⋮         ⋮         ⋮         ⋮
                                                41 rows omitted
In [16]:
df = DataFrame(chains)
Out[16]:
300×67 DataFrame
275 rows omitted
Row | iteration | chain | μ[1] | μ[2] | w[1] | w[2] | k[1] | k[2] | k[3] | k[4] | k[5] | k[6] | k[7] | k[8] | k[9] | k[10] | k[11] | k[12] | k[13] | k[14] | k[15] | k[16] | k[17] | k[18] | k[19] | k[20] | k[21] | k[22] | k[23] | k[24] | k[25] | k[26] | k[27] | k[28] | k[29] | k[30] | k[31] | k[32] | k[33] | k[34] | k[35] | k[36] | k[37] | k[38] | k[39] | k[40] | k[41] | k[42] | k[43] | k[44] | k[45] | k[46] | k[47] | k[48] | k[49] | k[50] | k[51] | k[52] | k[53] | k[54] | k[55] | k[56] | k[57] | k[58] | k[59] | k[60] | lp |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Int64 | Int64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | Float64 | |
1 | 1 | 1 | -1.82734 | 1.54074 | 0.773955 | 0.226045 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -360.777 |
2 | 2 | 1 | -3.08401 | 0.139212 | 0.586067 | 0.413933 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -300.726 |
3 | 3 | 1 | -3.11342 | 0.893437 | 0.629667 | 0.370333 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -275.341 |
4 | 4 | 1 | -3.24377 | 0.473691 | 0.665531 | 0.334469 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -272.02 |
5 | 5 | 1 | -3.07499 | 0.902173 | 0.53568 | 0.46432 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -271.323 |
6 | 6 | 1 | -3.47349 | 0.344424 | 0.561172 | 0.438828 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -236.109 |
7 | 7 | 1 | -3.53511 | 0.751037 | 0.527242 | 0.472758 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -233.713 |
8 | 8 | 1 | -3.3098 | 0.312821 | 0.550122 | 0.449878 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -234.722 |
9 | 9 | 1 | -3.61012 | 0.649756 | 0.401033 | 0.598967 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -235.133 |
10 | 10 | 1 | -3.40551 | 0.252875 | 0.565298 | 0.434702 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -235.228 |
11 | 11 | 1 | -3.34879 | 0.729625 | 0.491446 | 0.508554 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -235.685 |
12 | 12 | 1 | -3.54079 | 0.31537 | 0.51373 | 0.48627 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -233.997 |
13 | 13 | 1 | -3.52774 | 0.672473 | 0.559504 | 0.440496 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -234.377 |
⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ | ⋮ |
289 | 89 | 3 | -3.50551 | 0.484045 | 0.316322 | 0.683678 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -237.256 |
290 | 90 | 3 | -3.56928 | 0.546927 | 0.553467 | 0.446533 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -233.319 |
291 | 91 | 3 | -3.49348 | 0.634557 | 0.634908 | 0.365092 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -234.36 |
292 | 92 | 3 | -3.40305 | 0.518597 | 0.505162 | 0.494838 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -232.412 |
293 | 93 | 3 | -3.63169 | 0.469046 | 0.420914 | 0.579086 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -233.877 |
294 | 94 | 3 | -3.50029 | 0.361316 | 0.524629 | 0.475371 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -233.238 |
295 | 95 | 3 | -3.41818 | 0.631907 | 0.546171 | 0.453829 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -232.67 |
296 | 96 | 3 | -3.68305 | 0.410643 | 0.365514 | 0.634486 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -235.958 |
297 | 97 | 3 | -3.34128 | 0.472232 | 0.6095 | 0.3905 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -235.308 |
298 | 98 | 3 | -3.54003 | 0.567434 | 0.455095 | 0.544905 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -233.264 |
299 | 99 | 3 | -3.46683 | 0.42982 | 0.626881 | 0.373119 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -234.226 |
300 | 100 | 3 | -3.49859 | 0.549345 | 0.541908 | 0.458092 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 1.0 | 2.0 | 2.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 1.0 | 2.0 | 2.0 | 1.0 | 1.0 | 2.0 | 1.0 | 1.0 | -232.39 |
In [17]:
let
    # Verify that the output of the chain is as expected.
    for i in MCMCChains.chains(chains)
        # μ[1] and μ[2] can switch places, so we sort the values first.
        chain = Array(chains[:, ["μ[1]", "μ[2]"], i])
        μ_mean = vec(mean(chain; dims=1))
        println(sort(μ_mean), μ)
        @assert isapprox(sort(μ_mean), μ; rtol=0.1) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
    end
end
[-3.44476366255952, 0.5441864646433929][-3.5, 0.5]
[-3.483850095395723, 0.5179167541090655][-3.5, 0.5]
[-3.4930872357138765, 0.5047422668824073][-3.5, 0.5]
In [18]:
let
    Array(chains[:, ["μ[1]", "μ[2]"], 1])
end
Out[18]:
100×2 Matrix{Float64}:
 -1.82734  1.54074
 -3.08401  0.139212
 -3.11342  0.893437
 -3.24377  0.473691
 -3.07499  0.902173
 -3.47349  0.344424
 -3.53511  0.751037
 -3.3098   0.312821
 -3.61012  0.649756
 -3.40551  0.252875
 -3.34879  0.729625
 -3.54079  0.31537
 -3.52774  0.672473
  ⋮
 -3.59809  0.540053
 -3.49825  0.490688
 -3.49038  0.597805
 -3.43089  0.439063
 -3.58848  0.568874
 -3.51244  0.553333
 -3.37822  0.484359
 -3.51478  0.565656
 -3.51768  0.604015
 -3.47395  0.419305
 -3.56202  0.596524
 -3.61392  0.47095
In [19]:
let
    [mean.(eachcol(df[i:i+99, ["μ[1]", "μ[2]"]])) for i in 1:100:300]
end
Out[19]:
3-element Vector{Vector{Float64}}:
 [-3.44476366255952, 0.5441864646433929]
 [-3.483850095395723, 0.5179167541090655]
 [-3.4930872357138765, 0.5047422668824073]
In [20]:
plot(chains[["μ[1]","μ[2]"]]; colordim=:parameter,
legend=false, titlefontsize=9 )
Out[20]:
[plot: trace and density plots of μ[1] and μ[2]]
In [21]:
plot(chains[["w[1]","w[2]"]]; colordim=:parameter, legend=false, titlefontsize=9)
Out[21]:
[plot: trace and density plots of w[1] and w[2]]
In [22]:
chain = chains[:,:,1]
# Model with mean of samples as parameters.
μ_mean = [mean(chain, "μ[$i]") for i in 1:2]
w_mean = [mean(chain, "w[$i]") for i in 1:2]
mixturemodel_mean = MixtureModel([MvNormal(Fill(μₖ,2), I) for μₖ in μ_mean], w_mean)
contour(
    range(-7.5, 3; length=1_000),
    range(-6.5, 3; length=1_000),
    (x, y) -> logpdf(mixturemodel_mean, [x, y]);
    widen=false,
)
scatter!(x[1,:], x[2,:]; legend=false, title="Synthetic Dataset")
Out[22]:
[plot: log-density contours of the fitted mixture with the data overlaid]
Inferred Assignments
Finally, we can inspect the assignments of the data points inferred with Turing. As you can see, the dataset is split into two distinct groups.
In [23]:
assignments = [mean(chain, "k[$i]") for i in 1:N]
scatter(
    x[1,:],
    x[2,:];
    legend=false,
    title="Assignments on Synthetic Dataset",
    zcolor=assignments,
)
Out[23]:
[plot: data colored by posterior mean assignment]
In [24]:
assignments
Out[24]:
60-element Vector{Float64}:
 1.95
 2.0
 1.0
 1.95
 1.0
 1.99
 1.78
 1.0
 1.99
 1.0
 2.0
 1.0
 1.0
 ⋮
 1.0
 1.98
 1.0
 1.99
 1.41
 2.0
 1.97
 1.0
 1.0
 2.0
 1.0
 1.0
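Posterior means strictly between 1 and 2 (e.g. 1.41, 1.78) mark points that the sampler assigns to both clusters with nonzero probability. A minimal sketch for turning the means into hard labels:
In [ ]:
# Round the posterior mean assignments to the nearest cluster index.
hard_labels = round.(Int, assignments)
count(==(1), hard_labels), count(==(2), hard_labels)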