#!/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy as np
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.naive_bayes import ComplementNB as skComplementNB


# In[2]:


class ComplementNB():
    """Complement Naive Bayes (Rennie et al., 2003) with weight
    normalization, mirroring scikit-learn's ``ComplementNB(norm=True)``.

    Weights are estimated from the *complement* of each class (counts of
    every other class), which makes the estimator more robust to class
    imbalance than standard multinomial NB. Expects dense count features.
    """

    def __init__(self, alpha=1.0):
        # Additive (Lidstone) smoothing parameter.
        self.alpha = alpha

    def _encode(self, y):
        """Return (sorted unique classes, one-hot indicator matrix for y)."""
        classes = np.unique(y)
        # Broadcasting y against the class list yields the
        # (n_samples, n_classes) membership indicator in one shot.
        onehot = (y[:, np.newaxis] == classes).astype(np.float64)
        return classes, onehot

    def fit(self, X, y):
        """Estimate normalized complement log-weights from count matrix X."""
        self.classes_, onehot = self._encode(y)
        # Per-class feature counts, and totals pooled over all classes.
        self.feature_count_ = onehot.T @ X
        self.feature_all_ = self.feature_count_.sum(axis=0)
        # Smoothed counts of every class *except* c — the complement.
        comp_count = self.feature_all_ - self.feature_count_ + self.alpha
        comp_total = comp_count.sum(axis=1, keepdims=True)
        self.feature_log_prob_ = np.log(comp_count) - np.log(comp_total)
        # Normalize each class's weight vector by its (negative) L1 mass,
        # so every row of feature_log_prob_ sums to -1.
        scale = -self.feature_log_prob_.sum(axis=1, keepdims=True)
        self.feature_log_prob_ = self.feature_log_prob_ / scale
        return self

    def _joint_log_likelihood(self, X):
        """Per-class scores for each sample; lower is better because the
        weights describe the complement of the class."""
        return X @ self.feature_log_prob_.T

    def predict(self, X):
        """Assign each sample the class whose complement score is smallest."""
        scores = self._joint_log_likelihood(X)
        return self.classes_[scores.argmin(axis=1)]


# In[3]:


# Sanity-check our implementation against scikit-learn on 20 newsgroups.
newsgroups = fetch_20newsgroups()
texts, labels = newsgroups.data, newsgroups.target
# Densify the counts: our implementation does not support sparse input well.
counts = CountVectorizer(min_df=0.001).fit_transform(texts).toarray()

ours = ComplementNB().fit(counts, labels)
theirs = skComplementNB(norm=True).fit(counts, labels)

# Our learned weights carry the opposite sign of scikit-learn's, so the
# comparisons negate ours before checking numerical agreement.
assert np.allclose(-ours.feature_log_prob_, theirs.feature_log_prob_)
assert np.allclose(-ours._joint_log_likelihood(counts),
                   theirs._joint_log_likelihood(counts))
assert np.array_equal(ours.predict(counts), theirs.predict(counts))