#!/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy as np
from scipy.linalg import pinv
from scipy.special import logsumexp
from sklearn.datasets import load_iris
from sklearn.mixture import GaussianMixture as skGaussianMixture


# In[2]:


class GaussianMixture:
    """Minimal Gaussian mixture model with full covariances, fit by EM.

    Mirrors sklearn.mixture.GaussianMixture with covariance_type="full".
    Only init_params="random" (random responsibilities) is implemented;
    the argument is accepted for interface parity with scikit-learn.
    """

    def __init__(self, n_components, max_iter, init_params, random_state):
        self.n_components = n_components
        self.random_state = random_state
        self.max_iter = max_iter
        self.init_params = init_params

    def _estimate_weighted_log_prob(self, X):
        # log(weights_[k] * N(x | means_[k], covariances_[k])) for every
        # sample and component, where the Gaussian log-density is
        # -0.5 * (d*log(2*pi) + log|Sigma| + (x-mu)^T Sigma^{-1} (x-mu)).
        log_prob = np.zeros((X.shape[0], self.n_components))
        for k in range(self.n_components):
            diff = X - self.means_[k]
            # Row-wise squared Mahalanobis distance; einsum computes only the
            # diagonal, avoiding the n-by-n matrix that
            # np.diag(diff @ inv @ diff.T) would materialize.
            maha = np.einsum('ij,ij->i', np.dot(diff, pinv(self.covariances_[k])), diff)
            log_prob[:, k] = (-0.5 * X.shape[1] * np.log(2 * np.pi)
                              - 0.5 * np.log(np.linalg.det(self.covariances_[k]))
                              - 0.5 * maha)
        return np.log(self.weights_) + log_prob

    def _estimate_log_prob_resp(self, X):
        # Normalize in log space: log_resp[i, k] is the log responsibility of
        # component k for sample i; logsumexp keeps the subtraction stable.
        weighted_log_prob = self._estimate_weighted_log_prob(X)
        log_prob_norm = logsumexp(weighted_log_prob, axis=1)
        log_resp = weighted_log_prob - log_prob_norm[:, np.newaxis]
        return log_prob_norm, log_resp

    def _e_step(self, X):
        # The mean per-sample log-likelihood serves as the EM lower bound
        # used for the convergence check in fit().
        log_prob_norm, log_resp = self._estimate_log_prob_resp(X)
        return np.mean(log_prob_norm), log_resp

    def _m_step(self, X, resp):
        # Closed-form maximization of the expected complete-data
        # log-likelihood: nk is the effective number of samples assigned
        # to each component.
        nk = resp.sum(axis=0)
        weights = nk / X.shape[0]
        means = np.dot(resp.T, X) / nk[:, np.newaxis]
        covariances = np.empty((self.n_components, X.shape[1], X.shape[1]))
        for k in range(self.n_components):
            diff = X - means[k]
            # Responsibility-weighted scatter matrix of the centered data.
            covariances[k] = np.dot(resp[:, k] * diff.T, diff) / nk[k]
        return weights, means, covariances

    def fit(self, X):
        # Random-responsibility initialization followed by one M-step,
        # matching init_params="random" in scikit-learn. Use the seed the
        # caller passed in rather than a hard-coded one.
        rng = np.random.RandomState(self.random_state)
        resp = rng.rand(X.shape[0], self.n_components)
        resp /= resp.sum(axis=1)[:, np.newaxis]
        self.weights_, self.means_, self.covariances_ = self._m_step(X, resp)
        lower_bound = -np.inf
        self.converged_ = False
        for n_iter in range(1, self.max_iter + 1):
            prev_lower_bound = lower_bound
            lower_bound, log_resp = self._e_step(X)
            self.weights_, self.means_, self.covariances_ = self._m_step(X, np.exp(log_resp))
            change = lower_bound - prev_lower_bound
            self.n_iter_ = n_iter  # recorded even if max_iter is hit first
            if abs(change) < 1e-3:  # scikit-learn's default tol
                self.converged_ = True
                break
        return self

    def predict_proba(self, X):
        # Posterior responsibilities p(component k | x) for each sample.
        _, log_resp = self._estimate_log_prob_resp(X)
        return np.exp(log_resp)

    def predict(self, X):
        # The argmax of the weighted log probabilities equals the argmax of
        # the responsibilities, so normalization can be skipped.
        return np.argmax(self._estimate_weighted_log_prob(X), axis=1)
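

# In[ ]:


# Quick sanity check on synthetic data (a sketch added for illustration; the
# two well-separated blobs and the names X_toy / toy / resp_toy are not from
# the original notebook): every row of predict_proba must be a valid
# distribution over components, i.e. nonnegative and summing to 1.
rng = np.random.RandomState(1)
X_toy = np.vstack([rng.randn(50, 2) - 5.0, rng.randn(50, 2) + 5.0])
toy = GaussianMixture(n_components=2, max_iter=100,
                      init_params="random", random_state=0).fit(X_toy)
resp_toy = toy.predict_proba(X_toy)
assert resp_toy.shape == (100, 2)
assert np.all(resp_toy >= 0)
assert np.allclose(resp_toy.sum(axis=1), 1.0)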


# In[3]:


# Cross-check against scikit-learn on the iris data: mixture parameters,
# posterior probabilities, and hard assignments should all agree.
X, _ = load_iris(return_X_y=True)
clf1 = GaussianMixture(n_components=3, max_iter=100,
                       init_params="random", random_state=0).fit(X)
clf2 = skGaussianMixture(n_components=3, max_iter=100,
                         init_params="random", random_state=0).fit(X)
assert np.allclose(clf1.weights_, clf2.weights_, atol=1e-4)
assert np.allclose(clf1.means_, clf2.means_)
assert np.allclose(clf1.covariances_, clf2.covariances_, atol=1e-4)
prob1 = clf1.predict_proba(X)
prob2 = clf2.predict_proba(X)
assert np.allclose(prob1, prob2, atol=1e-3)
pred1 = clf1.predict(X)
pred2 = clf2.predict(X)
assert np.array_equal(pred1, pred2)
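

# In[ ]:


# Two further cross-checks (a sketch; the names wlp and ref are illustrative,
# and scipy.stats is assumed available alongside the scipy modules imported
# above): each column of the weighted log probabilities should equal
# log(weight_k) plus the log-density of the k-th fitted Gaussian, and
# logsumexp over components should reproduce scikit-learn's per-sample
# log-likelihood from score_samples.
from scipy.stats import multivariate_normal

wlp = clf1._estimate_weighted_log_prob(X)
for k in range(3):
    ref = (np.log(clf1.weights_[k])
           + multivariate_normal(clf1.means_[k], clf1.covariances_[k]).logpdf(X))
    assert np.allclose(wlp[:, k], ref)
assert np.allclose(logsumexp(wlp, axis=1), clf2.score_samples(X), atol=1e-3)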