# ---------------------------------------------------------------------------
# Environment setup
# ---------------------------------------------------------------------------
using Pkg; Pkg.activate("probprog/workspace/"); Pkg.instantiate();
IJulia.clear_output();

# ---------------------------------------------------------------------------
# Generate a synthetic two-class ("apple" vs "peach") data set
# ---------------------------------------------------------------------------
using Distributions, PyPlot

N = 250; p_apple = 0.7; Σ = [0.2 0.1; 0.1 0.3]
p_given_apple = MvNormal([1.0, 1.0], Σ) # p(X|y=apple)
p_given_peach = MvNormal([1.7, 2.5], Σ) # p(X|y=peach)
X = Matrix{Float64}(undef, 2, N); y = Vector{Bool}(undef, N) # true corresponds to apple
for n in eachindex(y)
    y[n] = (rand() < p_apple)                                   # Apple or peach?
    X[:, n] = y[n] ? rand(p_given_apple) : rand(p_given_peach)  # Sample features
end
# Split the features per class; logical indexing selects the matching columns.
# The adjoint makes each row one observation: column 1 is x_1, column 2 is x_2.
X_apples = X[:, y]'; X_peaches = X[:, .!y]'
x_test = [2.3; 1.5] # Features of 'new' data point

"""
    plot_fruit_dataset()

Plot the data set and the test point on the current PyPlot axes.

Draws the rows of the global `X_apples` as red `+` markers, the rows of the
global `X_peaches` as blue `x` markers and the global `x_test` as a black dot,
then adds a legend, axis labels and fixed axis limits (x in [-1, 3], y in [-1, 4]).
"""
function plot_fruit_dataset()
    plot(X_apples[:, 1], X_apples[:, 2], "r+")    # apples
    plot(X_peaches[:, 1], X_peaches[:, 2], "bx")  # peaches
    plot(x_test[1], x_test[2], "ko")              # 'new' unlabelled data point
    legend(["Apples"; "Peaches"; "Apple or peach?"], loc=2)
    xlabel(L"x_1"); ylabel(L"x_2"); xlim([-1,3]); ylim([-1,4])
end
plot_fruit_dataset();

# ---------------------------------------------------------------------------
# Density estimation
# Make sure you run the data-generating code cell first
# ---------------------------------------------------------------------------

# Multinomial (in this case binomial) density estimation
p_apple_est = count(y) / length(y)  # fraction of samples labelled apple
π_hat = [p_apple_est; 1 - p_apple_est]

# Estimate class-conditional multivariate Gaussian densities
d1 = fit_mle(FullNormal, X_apples')   # MLE density estimation d1 = N(μ₁, Σ₁)
d2 = fit_mle(FullNormal, X_peaches')  # MLE density estimation d2 = N(μ₂, Σ₂)
# Combine Σ₁ and Σ₂ into Σ, weighted by the estimated class priors.
# NOTE: this overwrites the global Σ that the data-generating code above reads,
# so re-running that code afterwards samples with the *estimated* covariance.
Σ = π_hat[1]*cov(d1) + π_hat[2]*cov(d2)
conditionals = [MvNormal(mean(d1), Σ); MvNormal(mean(d2), Σ)] # p(x|C)

"""
    predict_class(k, X)

Return the posterior class probability p(Ck|X) computed with Bayes rule from the
global priors `π_hat` and the global class-conditional densities `conditionals`.

`k` indexes into `π_hat` and `conditionals`; `X` is passed unchanged to `pdf`.
"""
function predict_class(k, X)
    # Evidence p(X): the prior-weighted sum of the class-conditional densities.
    evidence = sum(π_hat[j] * pdf(conditionals[j], X) for j in eachindex(conditionals))
    return π_hat[k] * pdf(conditionals[k], X) ./ evidence
end
println("p(apple|x=x∙) = $(predict_class(1,x_test))")

# ---------------------------------------------------------------------------
# Discrimination boundary of the posterior (p(apple|x;D) = p(peach|x;D) = 0.5)
# ---------------------------------------------------------------------------

# Linear coefficient of the class-k discriminant, Σ⁻¹ μₖ.
# Computed with a linear solve rather than by forming the explicit inverse.
β(k) = Σ \ mean(conditionals[k])

# Offset of the class-k discriminant, -½ μₖᵀ Σ⁻¹ μₖ + log(πₖ); a scalar.
function γ(k)
    μ = mean(conditionals[k])
    return -0.5 * (μ' * (Σ \ μ)) + log(π_hat[k])
end

"""
    discriminant_x2(x1)

Solve the discriminant equation `β12[1]*x1 + β12[2]*x2 + γ12 = 0` for `x2`, where
`β12 = β(1) - β(2)` and `γ12 = γ(1) - γ(2)`.

`x1` may be a scalar or a collection; the result has the matching shape. The
result is not finite when `β12[2]` is zero (boundary parallel to the x2 axis).
"""
function discriminant_x2(x1)
    β12 = β(1) .- β(2)
    γ12 = γ(1) - γ(2)
    return -1*(β12[1]*x1 .+ γ12) ./ β12[2]
end

plot_fruit_dataset() # Plot dataset
x1 = range(-1,length=10,stop=3)
x2_boundary = discriminant_x2(x1)  # evaluate the boundary once and reuse it below
plot(x1, x2_boundary, "k-")        # Plot discrimination boundary
# Shade from the lower axis limit (-1) up to the boundary in red (the apple
# marker colour) and from the boundary up to the upper axis limit (4) in blue
# (the peach marker colour).
fill_between(x1, -1, x2_boundary, color="r", alpha=0.2)
fill_between(x1, x2_boundary, 4, color="b", alpha=0.2);

# ---------------------------------------------------------------------------
# Notebook styling: render the shared HTML style sheet in the output cell
# ---------------------------------------------------------------------------
open("../../styles/aipstyle.html") do f
    display("text/html", read(f,String))
end