Tutorial - Gaussian Mixture Model¶

In [278]:
using Distributions
using FillArrays
using StatsPlots

using LinearAlgebra
using Random
plotlyjs()
Out[278]:
Plots.PlotlyJSBackend()
In [279]:
Random.seed!(3)
Out[279]:
TaskLocalRNG()
In [280]:
# w는 합해서 1이 되어야 한다.
w = [0.5, 0.5]
μ = [-3.5, 0.5]

mixturemodel = MixtureModel([MvNormal(Fill(μₖ,2), I) for μₖ in μ], w)
Out[280]:
MixtureModel{MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}}(K = 2)
components[1] (prior = 0.5000): MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(-3.5, 2)
Σ: [1.0 0.0; 0.0 1.0]
)

components[2] (prior = 0.5000): MvNormal{Float64, PDMats.ScalMat{Float64}, Fill{Float64, 1, Tuple{Base.OneTo{Int64}}}}(
dim: 2
μ: Fill(0.5, 2)
Σ: [1.0 0.0; 0.0 1.0]
)

In [281]:
N = 6000
x = rand(mixturemodel, N)
density(x[1,:])
density!(x[2,:])
Out[281]:
In [282]:
mixture_pdf(x) = pdf(mixturemodel,x)
mixture_pdf([[0.0,4.0],[-1,2]])
Out[282]:
2-element Vector{Float64}:
 0.00015362065909652362
 0.00838740473741466
In [283]:
xs = [collect(t) for t in zip(x[1,:],x[2,:])];
In [284]:
xs_pdf = mixture_pdf(x);
In [285]:
scatter3d(x[1,:],x[2,:],xs_pdf; markersize=0.1, legend=false, 
    title="Synthetic Dataset", color=:red, size=(800,400) )
Out[285]:

Gaussian Mixture Model in Turing¶

주어진 데이터 세트에서 그룹핑을 복구하고자 한다. 즉 혼합 가중치(weight)와 모수(parameter) $\mu_1$,$\mu_2$ 그리고 생성 가우시안 혼합 모델을 위한 각 데이터의 클러스터 할당을 추론 하고자 한다

$k$개의 클러스터로 구성된 데이터 $x_i\;(i=1,...,N)$은 다음 생성 프로세스에 따라 생성된다.

    1. 모델 파라미터 즉 $\mu_k$를 뽑는다. $\mu_k \sim \mathcal{N}(0,1)\;\;(k=1,...,K)$, $k$가 뽑힐 가중치 $w$는 $w \sim Dirichlet(\alpha_1,...,\alpha_k)$
    1. 모델 파라미터 설정이 끝났어면 관찰 데이터$(N)$개를 생성 하기 위해 클러스터를 선택 하고 그 클러스터 근처에 있는 관찰 데이터를 뽑는다.
    • $z_i \sim Categorical(w),\;\;\;(i=1,...,N)$
    • $x_i \sim \mathcal{N}([\mu_{z_i},\mu_{z_i}]^T,I),\;\;\;(i=1,...,N)$

디리클레분포(Dirichlet distribution)는 베타분포의 확장판이라고 할 수 있다. 베타분포는 0과 1사이의 값을 가지는 단일(univariate) 확률변수의 베이지안 모형에 사용되고 디리클레분포는 0과 1사이의 사이의 값을 가지는 다변수(multivariate) 확률변수의 베이지안 모형에 사용된다.

In [286]:
let
    x = rand(Dirichlet(2,1.0),1000000)
    density(x[1,:])
    density!(x[2,:])
end
Out[286]:
In [287]:
w = rand(Dirichlet(2,1.0))
size(w) |> display
w |> display
rand(Categorical(w),5)
(2,)
2-element Vector{Float64}:
 0.3645928706520138
 0.6354071293479864
Out[287]:
5-element Vector{Int64}:
 2
 2
 1
 1
 2
In [288]:
rand(MvNormal(Zeros(2), I),1)
Out[288]:
2×1 Matrix{Float64}:
  1.816361101257206
 -1.1862323182615306
In [289]:
using Turing
using DataFrames
In [290]:
@model function gaussian_mixture_model(x)
    # Draw the parameters for each of the K=2 clusters from a standard normal distribution.
    K = 2
    # 각 축별로 값을 추출한다.
    # 예) [-0.82 ,1.14]
    #     각 축별로 추출한 값이 각각 중심이 아래와 같은 클러스터가 된다 
    #     [-0.82, -0,82], [1.14, 1.14]
    μ ~ MvNormal(Zeros(K), I)
    
    # Draw the weights for the K clusters from a Dirichlet distribution with parameters αₖ = 1.
    # 어떤 클러스를 뽑을지 가중치 결정
    # ∑wᵢ = 1
    w ~ Dirichlet(K, 1.0)
    
    # Construct categorical distribution of assignments.
    # 가중치에 따른 클러스트 선택
    distribution_assignments = Categorical(w)
    
    # Construct categorical distribution of assignments.
    # D : dimension, N : data count
    D, N = size(x)
    #  데이터를 뽑을 각 클러스터의 분포 설정
    distribution_clusters = [MvNormal(Fill(μₖ, D), I) for μₖ in μ]
    
    # Draw assignments for each datum and generate it from the multivariate
    # normal distribution.
    k = Vector{Int}(undef,N)
    for i in 1:N
        # 클러스터 선택
        k[i] ~ distribution_assignments
        # x는 선택된 클러스터의 분포를 따른다.
        x[:,i] ~ distribution_clusters[k[i]]
    end
    
    return k
end
Out[290]:
gaussian_mixture_model (generic function with 2 methods)
In [291]:
N = 60
x = rand(mixturemodel, N)
model = gaussian_mixture_model(x)
Out[291]:
DynamicPPL.Model{typeof(gaussian_mixture_model), (:x,), (), (), Tuple{Matrix{Float64}}, Tuple{}, DynamicPPL.DefaultContext}(gaussian_mixture_model, (x = [-1.995712209060883 -3.5535472667182626 … -4.4501339298974445 1.0566045473216428; -0.43691107556630615 -2.843174596703907 … -4.036749975745415 -0.65206523018811],), NamedTuple(), DynamicPPL.DefaultContext())

매개변수 𝜇 및 𝑤와 할당 𝑘 의 사후 분포 근사치를 얻기 위해 MCMC 시뮬레이션을 실행합니다. 이산 파라미터(할당 𝑘)에는 입자 깁스 샘플러를, 연속 파라미터(𝜇 및 𝑤)에는 해밀토니온 몬테카를로 샘플러를 결합한 깁스 샘플러를 사용합니다. 멀티 스레딩을 사용하여 여러 체인을 병렬로 생성합니다.

In [292]:
g_sampler = Gibbs(PG(100,:k), HMC(0.05, 10, :μ, :w))
n_samples = 100
n_chains = 3
chains = sample(model, g_sampler, MCMCThreads(), n_samples, n_chains)
Sampling (3 threads): 100%|█████████████████████████████| Time: 0:00:04
Out[292]:
Chains MCMC chain (100×65×3 Array{Float64, 3}):

Iterations        = 1:1:100
Number of chains  = 3
Samples per chain = 100
Wall duration     = 320.95 seconds
Compute duration  = 957.58 seconds
parameters        = μ[1], μ[2], w[1], w[2], k[1], k[2], k[3], k[4], k[5], k[6], k[7], k[8], k[9], k[10], k[11], k[12], k[13], k[14], k[15], k[16], k[17], k[18], k[19], k[20], k[21], k[22], k[23], k[24], k[25], k[26], k[27], k[28], k[29], k[30], k[31], k[32], k[33], k[34], k[35], k[36], k[37], k[38], k[39], k[40], k[41], k[42], k[43], k[44], k[45], k[46], k[47], k[48], k[49], k[50], k[51], k[52], k[53], k[54], k[55], k[56], k[57], k[58], k[59], k[60]
internals         = lp

Summary Statistics
  parameters      mean       std      mcse   ess_bulk   ess_tail               ⋯
      Symbol   Float64   Float64   Float64    Float64    Float64               ⋯

        μ[1]   -0.8084    1.9205    1.0593     5.1302    63.4218               ⋯
        μ[2]   -2.1337    1.9045    1.0443     5.1449    93.6059               ⋯
        w[1]    0.4946    0.0789    0.0074   236.3241   127.7179               ⋯
        w[2]    0.5054    0.0789    0.0074   236.3241   127.7179               ⋯
        k[1]    1.4133    0.4933    0.2142     5.3004        NaN               ⋯
        k[2]    1.6667    0.4722    0.2615     3.2609        NaN               ⋯
        k[3]    1.3600    0.4808    0.2505     3.6839        NaN               ⋯
        k[4]    1.6667    0.4722    0.2615     3.2609        NaN               ⋯
        k[5]    1.3467    0.4767    0.2578     3.4192        NaN               ⋯
        k[6]    1.6667    0.4722    0.2615     3.2609        NaN               ⋯
        k[7]    1.6667    0.4722    0.2615     3.2609        NaN               ⋯
        k[8]    1.3333    0.4722    0.2615     3.2609        NaN   15752611252 ⋯
        k[9]    1.6667    0.4722    0.2615     3.2609        NaN               ⋯
       k[10]    1.6667    0.4722    0.2615     3.2609        NaN               ⋯
       k[11]    1.6667    0.4722    0.2615     3.2609        NaN               ⋯
       k[12]    1.6667    0.4722    0.2615     3.2609        NaN               ⋯
       k[13]    1.3433    0.4756    0.2587     3.3812        NaN               ⋯
       k[14]    1.6667    0.4722    0.2615     3.2609        NaN               ⋯
       k[15]    1.3433    0.4756    0.2553     3.4714        NaN               ⋯
       k[16]    1.3433    0.4756    0.2587     3.3812        NaN               ⋯
       k[17]    1.6667    0.4722    0.2615     3.2609        NaN               ⋯
       k[18]    1.6600    0.4745    0.2594     3.3450        NaN               ⋯
       k[19]    1.6667    0.4722    0.2615     3.2609        NaN               ⋯
      ⋮           ⋮         ⋮         ⋮         ⋮          ⋮                 ⋮ ⋱
                                                   2 columns and 41 rows omitted

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

        μ[1]   -3.7010   -3.4167    0.4261    0.5800    0.8828
        μ[2]   -3.7804   -3.5509   -3.3976    0.4003    0.7006
        w[1]    0.3701    0.4606    0.4989    0.5437    0.6075
        w[2]    0.3925    0.4563    0.5011    0.5394    0.6299
        k[1]    1.0000    1.0000    1.0000    2.0000    2.0000
        k[2]    1.0000    1.0000    2.0000    2.0000    2.0000
        k[3]    1.0000    1.0000    1.0000    2.0000    2.0000
        k[4]    1.0000    1.0000    2.0000    2.0000    2.0000
        k[5]    1.0000    1.0000    1.0000    2.0000    2.0000
        k[6]    1.0000    1.0000    2.0000    2.0000    2.0000
        k[7]    1.0000    1.0000    2.0000    2.0000    2.0000
        k[8]    1.0000    1.0000    1.0000    2.0000    2.0000
        k[9]    1.0000    1.0000    2.0000    2.0000    2.0000
       k[10]    1.0000    1.0000    2.0000    2.0000    2.0000
       k[11]    1.0000    1.0000    2.0000    2.0000    2.0000
       k[12]    1.0000    1.0000    2.0000    2.0000    2.0000
       k[13]    1.0000    1.0000    1.0000    2.0000    2.0000
       k[14]    1.0000    1.0000    2.0000    2.0000    2.0000
       k[15]    1.0000    1.0000    1.0000    2.0000    2.0000
       k[16]    1.0000    1.0000    1.0000    2.0000    2.0000
       k[17]    1.0000    1.0000    2.0000    2.0000    2.0000
       k[18]    1.0000    1.0000    2.0000    2.0000    2.0000
       k[19]    1.0000    1.0000    2.0000    2.0000    2.0000
      ⋮           ⋮         ⋮         ⋮         ⋮         ⋮
                                                 41 rows omitted
In [293]:
df = DataFrame(chains)
Out[293]:
300×67 DataFrame
275 rows omitted
Rowiterationchainμ[1]μ[2]w[1]w[2]k[1]k[2]k[3]k[4]k[5]k[6]k[7]k[8]k[9]k[10]k[11]k[12]k[13]k[14]k[15]k[16]k[17]k[18]k[19]k[20]k[21]k[22]k[23]k[24]k[25]k[26]k[27]k[28]k[29]k[30]k[31]k[32]k[33]k[34]k[35]k[36]k[37]k[38]k[39]k[40]k[41]k[42]k[43]k[44]k[45]k[46]k[47]k[48]k[49]k[50]k[51]k[52]k[53]k[54]k[55]k[56]k[57]k[58]k[59]k[60]lp
Int64Int64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64Float64
1110.806586-1.469890.162770.837232.02.02.02.01.02.02.01.02.02.02.02.01.02.02.01.02.02.02.02.02.01.02.02.02.02.02.01.02.02.02.01.01.02.02.02.02.02.02.01.02.02.02.02.02.02.02.01.02.02.02.01.02.02.02.02.02.02.02.01.0-399.401
2211.00127-3.09380.4370760.5629242.02.02.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.02.01.02.02.02.01.01.02.01.02.02.01.02.02.01.02.02.02.02.01.02.01.02.02.02.01.02.02.01.01.01.02.02.01.0-296.311
3311.00127-3.09380.4370760.5629242.02.02.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.02.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.02.01.01.02.02.01.02.02.01.01.01.01.02.01.0-252.515
4410.460392-3.590380.4186980.5813022.02.02.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.02.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.02.01.01.02.02.01.02.02.01.01.01.01.02.01.0-254.984
5510.890698-2.828790.4695830.5304172.02.02.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.02.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.01.01.02.02.02.01.02.02.01.01.01.01.02.01.0-261.595
6610.352181-3.785780.5347440.4652562.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.02.01.01.02.02.01.02.02.01.01.01.01.02.01.0-238.056
7710.591203-3.294330.4262170.5737832.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.01.01.01.01.02.01.02.02.01.01.01.01.02.01.0-240.432
8810.741131-3.544890.4437920.5562082.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.02.01.01.02.02.01.02.02.01.01.01.01.02.01.0-232.511
9910.651008-3.449310.4779320.5220682.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.02.01.01.02.02.01.02.02.01.01.01.01.02.01.0-231.401
101010.359191-3.496570.5203040.4796962.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.01.01.01.02.02.01.02.02.01.01.01.01.02.01.0-233.431
111110.764042-3.320440.3751320.6248681.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.02.01.01.02.02.01.02.02.01.01.01.01.02.01.0-233.765
121210.483242-3.427880.4819130.5180872.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.02.01.01.02.02.01.02.02.01.01.01.01.02.01.0-231.863
131310.542634-3.425310.4464960.5535042.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.01.01.01.02.02.01.02.02.01.01.01.01.02.01.0-232.734
⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮⋮
2898930.431869-3.463590.518840.481161.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.01.01.01.02.02.01.02.02.01.01.01.01.02.01.0-230.338
2909030.382087-3.59370.3743340.6256661.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.01.01.01.02.02.01.02.02.01.01.01.01.02.01.0-232.983
2919130.382087-3.59370.3743340.6256661.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.01.01.01.02.02.01.02.02.01.01.01.01.02.01.0-232.983
2929230.51947-3.609650.5644780.4355221.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.01.01.01.02.02.01.02.02.01.01.01.01.02.01.0-230.41
2939330.46369-3.477640.5156130.4843871.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.01.01.01.02.02.01.02.02.01.01.01.01.02.01.0-230.184
2949430.362756-3.60610.5440.4561.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.01.01.01.02.02.01.02.02.01.01.01.01.02.01.0-230.689
2959530.631187-3.548050.4723720.5276281.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.01.01.01.02.02.01.02.02.01.01.01.01.02.01.0-230.857
2969630.483107-3.43990.5674080.4325921.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.02.01.01.02.02.01.02.02.01.01.01.01.02.01.0-230.423
2979730.490669-3.48310.4020050.5979951.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.02.01.01.02.02.01.02.02.01.01.01.01.02.01.0-230.969
2989830.596414-3.304520.4409850.5590151.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.02.01.01.02.02.01.02.02.01.01.01.01.02.01.0-231.115
2999930.769064-3.76950.5112880.4887121.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.02.01.01.02.02.01.02.02.01.01.01.01.02.01.0-233.438
30010030.175728-3.389950.469760.530241.02.01.02.01.02.02.01.02.02.02.02.01.02.01.01.02.02.02.01.02.01.02.01.02.02.01.01.02.02.02.01.01.01.01.02.02.01.02.01.01.01.02.01.02.01.02.01.01.02.02.01.02.02.01.01.01.01.02.01.0-234.459
In [294]:
let
    # Verify that the output of the chain is as expected.
    for i in MCMCChains.chains(chains)
        # μ[1] and μ[2] can switch places, so we sort the values first.
        chain = Array(chains[:,["μ[1]","μ[2]"],i])
        μ_mean = vec(mean(chain; dims=1))
        println(sort(μ_mean), μ)
        @assert isapprox(sort(μ_mean), μ; rtol=0.1) "Difference between estimated mean of μ ($(sort(μ_mean))) and data-generating μ ($μ) unexpectedly large!"
    end
end
[-3.4843836622735846, 0.5394186230220864][-3.5, 0.5]
[-3.5072251084217667, 0.5232151482050279][-3.5, 0.5]
[-3.4400456293350423, 0.5426133471604079][-3.5, 0.5]
In [295]:
let    
   Array(chains[:,["μ[1]","μ[2]"],1])
end
Out[295]:
100×2 Matrix{Float64}:
 0.806586  -1.46989
 1.00127   -3.0938
 1.00127   -3.0938
 0.460392  -3.59038
 0.890698  -2.82879
 0.352181  -3.78578
 0.591203  -3.29433
 0.741131  -3.54489
 0.651008  -3.44931
 0.359191  -3.49657
 0.764042  -3.32044
 0.483242  -3.42788
 0.542634  -3.42531
 ⋮         
 0.599547  -3.24166
 0.361801  -3.70378
 0.720165  -3.33989
 0.283199  -3.72268
 0.758261  -3.33467
 0.225492  -3.59689
 0.623473  -3.57775
 0.273148  -3.61686
 0.758842  -3.54824
 0.355432  -3.47965
 0.635855  -3.5731
 0.383734  -3.55743
In [296]:
let
    [mean.(eachcol(df[i:i+99,["μ[1]","μ[2]"]])) for i = 1:100:300]
end
Out[296]:
3-element Vector{Vector{Float64}}:
 [0.5394186230220864, -3.4843836622735846]
 [-3.5072251084217667, 0.5232151482050279]
 [0.5426133471604079, -3.4400456293350423]

Inffered Mixture Model¶

샘플링 후 관심 있는 파라미터의 추적과 밀도를 시각화할 수 있습니다.

위치 매개변수 $\mu_1$ 과 $\mu_2$의 샘플을 고려합니다.

In [297]:
plot(chains[["μ[1]","μ[2]"]]; colordim=:parameter, 
    legend=false, titlefontsize=9  )
Out[297]:
In [298]:
plot(chains[["w[1]","w[2]"]]; colordim=:parameter, legend=false, titlefontsize=9)
Out[298]:
In [299]:
chain = chains[:,:,1]

# Model with mean of samples as parameters.
μ_mean = [mean(chain, "μ[$i]") for i in 1:2]
w_mean = [mean(chain, "w[$i]") for i in 1:2]

mixturemodel_mean = MixtureModel([MvNormal(Fill(μₖ,2), I) for μₖ in μ_mean], w_mean)

contour(
    range(-7.5, 3; length=1_000),
    range(-6.5, 3; length=1_000),
    (x,y) -> logpdf(mixturemodel_mean,[x, y]);
    widen = false)
scatter!(x[1,:], x[2,:]; legend=false, title="Sythetic Dataset")
Out[299]:

Inferred Assignments¶

마지막으로, 튜링을 사용하여 추론된 데이터 포인트의 할당을 검사할 수 있습니다. 보시다시피, 데이터 집합은 두 개의 서로 다른 그룹으로 분할되어 있습니다.

In [300]:
assignments = [mean(chain, "k[$i]") for i in 1:N]
scatter(
    x[1,:],
    x[2,:],
    legend=false,
    title="Assignments on Synthetic Dataset",
    zcolor=assignments,
    )
Out[300]:
In [301]:
assignments
Out[301]:
60-element Vector{Float64}:
 1.19
 2.0
 1.05
 2.0
 1.0
 2.0
 2.0
 1.0
 2.0
 2.0
 2.0
 2.0
 1.0
 ⋮
 1.03
 1.99
 2.0
 1.0
 2.0
 2.0
 1.01
 1.01
 1.01
 1.02
 2.0
 1.0
In [ ]:

In [ ]: