using Distributions, StatsBase, StatsFuns, Plots, Random, SpecialFunctions

# ---------------------------------------------------------------------------
# Synthetic data: 3-component Poisson mixture whose rates are drawn from Gammas
# ---------------------------------------------------------------------------
N = 10000                                    # number of samples
Sn = rand(Categorical([0.3, 0.3, 0.4]), N)   # true latent cluster of every sample
Khyper = [2 3; 8 4; 15 1]                    # row k: Gamma(shape, scale) used to draw λ[k]

# Visualise the three Gamma densities the true rates are drawn from.
xgrid = 0:0.01:40
plot(xgrid, pdf.(Gamma(Khyper[1, 1], Khyper[1, 2]), xgrid))
plot!(xgrid, pdf.(Gamma(Khyper[2, 1], Khyper[2, 2]), xgrid))
plot!(xgrid, pdf.(Gamma(Khyper[3, 1], Khyper[3, 2]), xgrid))

# NOTE(review): the seed is set *after* `Sn` was drawn above, so `Sn` (and therefore `X`)
# still differs between runs. Kept in place because the hard-coded column permutation in
# the evaluation section was chosen for runs produced this way.
Random.seed!(204)
λ = [rand(Gamma(Khyper[k, 1], Khyper[k, 2])) for k in 1:3]   # true rate of each cluster
X = [rand(Poisson(λ[s])) for s in Sn]                         # observed counts
X'
plot(fit(Histogram, X, nbins = 100))

"""
    MixturePoissonVB

Result of [`vb`](@ref).

# Fields
- `η`: `N × nK` matrix of responsibilities (each row sums to one).
- `λ`: posterior mean rate of each cluster, computed as `shape ./ scale`.
- `shape`: first Gamma parameter of each cluster (`a + Σ η .* X`).
- `scale`: second Gamma parameter of each cluster (`b + Σ η`). Despite the field name it
  is used as a divisor (`λ = shape / scale`), i.e. it acts like a rate.
- `α`: Dirichlet parameter of each cluster (`α + Σ η`).
"""
struct MixturePoissonVB
    η
    λ
    shape
    scale
    α
end

# Refresh the Gamma/Dirichlet parameters and their expectations in place from the
# per-cluster weighted data sum `wX` and weight sum `wsum`.
function _vb_update_posterior!(a1, b1, α1, λ1, lnλ1, lnπ1, wX, wsum, a, b, α)
    a1 .= wX .+ a
    b1 .= wsum .+ b
    α1 .= wsum .+ α
    λ1 .= a1 ./ b1
    lnλ1 .= digamma.(a1) .- log.(b1)
    # The normaliser needs the *complete* α1, so it is evaluated only after every
    # component has been updated (computing it inside a per-k loop uses a partial sum).
    lnπ1 .= digamma.(α1) .- digamma(sum(α1))
    return nothing
end

# Fill the responsibility matrix `η1` row by row from the current expectations.
function _vb_update_responsibilities!(η1, X, λ1, lnλ1, lnπ1)
    for i in eachindex(X)
        logρ = X[i] .* lnλ1 .- λ1 .+ lnπ1
        # Normalise in log space: exp(logρ) overflows/underflows for large counts and
        # would turn the row into NaN.
        η1[i, :] .= exp.(logρ .- logsumexp(logρ))
    end
    return nothing
end

"""
    vb(X, nK, MAXITER = 10, a = ..., b = ..., α = ...)

Fit an `nK`-component Poisson mixture to the counts `X` by variational Bayes.

# Arguments
- `X`: vector of observed counts.
- `nK`: number of mixture components.
- `MAXITER`: sweeps are numbered `0:MAXITER`, so `MAXITER + 1` E/M sweeps are performed.
- `a`: first Gamma hyper-parameter per component (added to the weighted sum of `X`).
- `b`: second Gamma hyper-parameter per component. It is added to the weighted sample
  count and used as a divisor, i.e. it is treated as a rate.
  NOTE(review): `Distributions.Gamma` takes a *scale*; confirm what callers intend.
- `α`: Dirichlet hyper-parameter per component.

`a`, `b` and `α` default to `nK` random integers sampled from `1:40`.

# Returns
A [`MixturePoissonVB`](@ref) holding the final responsibilities and parameters.

Progress (the current `λ` before each sweep) is printed to standard output.
"""
function vb(X, nK,
            MAXITER = 10,
            a = sample(1:40, nK),   # Gamma hyper-parameter (first)
            b = sample(1:40, nK),   # Gamma hyper-parameter (second, used as a rate)
            α = sample(1:40, nK))   # Dirichlet hyper-parameter
    N = length(X)

    # Initialise with a uniformly random hard assignment, stored as a one-hot matrix.
    Sn = rand(Categorical(ones(nK) / nK), N)
    S = zeros(Int64, N, nK)
    println("Initialize latent matrix")
    for (i, s) in enumerate(Sn)
        S[i, s] = 1
    end

    println("Initialize parameter vectors")
    a1 = zeros(nK)
    b1 = zeros(nK)
    α1 = zeros(nK)
    λ1 = zeros(nK)
    lnλ1 = zeros(nK)
    lnπ1 = zeros(nK)
    η1 = zeros(N, nK)
    _vb_update_posterior!(a1, b1, α1, λ1, lnλ1, lnπ1,
                          S' * X, vec(sum(S, dims = 1)), a, b, α)

    # VB iteration
    for iter in 0:MAXITER
        print("Itaration", iter, "... λ is ")
        println(λ1)

        # Expectation of the latent assignments
        _vb_update_responsibilities!(η1, X, λ1, lnλ1, lnπ1)

        # Expectation of λ and π
        _vb_update_posterior!(a1, b1, α1, λ1, lnλ1, lnπ1,
                              η1' * X, vec(sum(η1, dims = 1)), a, b, α)
    end

    return MixturePoissonVB(η1, λ1, a1, b1, α1)
end

# ---------------------------------------------------------------------------
# Fit and evaluate
# ---------------------------------------------------------------------------
λ   # true values

@time res1 = vb(X, 3, 100)
# NOTE(review): Khyper[:, 2] was used as a Gamma *scale* when simulating, but `vb` treats
# its `b` argument as a rate — confirm whether `1 ./ Khyper[:, 2]` was intended.
@time res2 = vb(X, 3, 100, Khyper[:, 1], Khyper[:, 2], [1.0, 1.0, 1.0])
[res2.η X Sn]

using DataFrames
# Cluster labels found by VB are arbitrary; this hard-coded permutation lines the
# estimated clusters up with the true ones for the run it was written for.
df_tmp = res2.η[:, [2, 1, 3]]
# NOTE(review): `convert(DataFrame, ::Matrix)` is not available in recent DataFrames
# releases (use `DataFrame(mat, :auto)` there) — confirm the targeted version.
df = convert(DataFrames.DataFrame, [df_tmp X Sn])

# Most probable cluster of each sample. `argmax` returns the first maximum, so ties no
# longer raise an error.
which_max = [argmax(view(df_tmp, i, :)) for i in 1:N]
which_max'

mismatch = Sn .!= which_max
error_count = Sn[mismatch]
countmap(error_count)
println("Error rate is ", count(mismatch) / N * 100, "%")
plot(fit(Histogram, error_count), xlab = "Latent Cluster")