#!/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy as np
from scipy.special import expit, logsumexp
from sklearn.datasets import load_breast_cancer, load_iris
from sklearn.tree import DecisionTreeRegressor
from sklearn.ensemble import GradientBoostingClassifier as skGradientBoostingClassifier


# In[2]:


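# A from-scratch reimplementation of scikit-learn's GradientBoostingClassifier
# with log-loss (deviance) and init="zero". Each boosting stage fits a
# regression tree to the negative gradient of the loss (the "residual"),
# then replaces every leaf value with a single Newton-Raphson step
# (sum of residuals over sum of Hessians) before adding the shrunken tree
# output to the running raw predictions.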
class GradientBoostingClassifier():
    def __init__(self, learning_rate=0.1, n_estimators=100, max_depth=3, random_state=0):
        self.learning_rate = learning_rate
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.random_state = random_state

    def fit(self, X, y):
        self.n_features_ = X.shape[1]
        self.classes_, y = np.unique(y, return_inverse=True)
        self.n_classes_ = len(self.classes_)
        if self.n_classes_ == 2:
            n_effective_classes = 1
        else:
            n_effective_classes = self.n_classes_
        self.estimators_ = np.empty((self.n_estimators, n_effective_classes), dtype=object)
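        # raw scores start at zero, matching scikit-learn's init="zero"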
        raw_predictions = np.zeros((X.shape[0], n_effective_classes))
        rng = np.random.RandomState(self.random_state)
        for i in range(self.n_estimators):
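            # every class-wise tree in a stage must see the scores as they were
            # at the start of the stage, so work from a copy while the running
            # raw_predictions is updated in place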
            raw_predictions_copy = raw_predictions.copy()
            for j in range(n_effective_classes):
                # binary classification: the residual is the negative gradient
                # of the binomial deviance, y - sigmoid(raw_score)
                if n_effective_classes == 1:
                    y_enc = y
                    residual = y_enc - expit(raw_predictions_copy.ravel())
                # multiclass classification: one-vs-rest encoding per class;
                # the residual is y_enc - softmax(raw_score)[:, j], computed
                # stably through logsumexp
                else:
                    y_enc = (y == j).astype(int)
                    residual = y_enc - np.nan_to_num(np.exp(raw_predictions_copy[:, j]
                                                            - logsumexp(raw_predictions_copy, axis=1)))
                tree = DecisionTreeRegressor(criterion="friedman_mse", max_depth=self.max_depth,
                                             random_state=rng)
                tree.fit(X, residual)
                terminal_regions = tree.apply(X)
                for leaf in np.where(tree.tree_.children_left == -1)[0]:
                    cur = np.where(terminal_regions == leaf)[0]
                    # binary classification: one Newton-Raphson step per leaf,
                    # where the Hessian p * (1 - p) is recovered from the
                    # residual as (y_enc - residual) * (1 - y_enc + residual)
                    if n_effective_classes == 1:
                        numerator = np.sum(residual[cur])
                        denominator = np.sum((y_enc[cur] - residual[cur]) * (1 - y_enc[cur] + residual[cur]))
                    # multiclass classification: same Newton step, with the
                    # (K - 1) / K scaling used by the multinomial deviance
                    else:
                        numerator = np.sum(residual[cur])
                        numerator *= (self.n_classes_ - 1) / self.n_classes_
                        denominator = np.sum((y_enc[cur] - residual[cur]) * (1 - y_enc[cur] + residual[cur]))
                    if np.abs(denominator) < 1e-150:
                        tree.tree_.value[leaf, 0, 0] = 0
                    else:
                        tree.tree_.value[leaf, 0, 0] = numerator / denominator
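                # apply shrinkage and add the refitted leaf values to the running scores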
                raw_predictions[:, j] += self.learning_rate * tree.tree_.value[:, 0, 0][terminal_regions]
                self.estimators_[i, j] = tree
        return self

    def _predict(self, X):
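        # accumulate the shrunken contribution of every stage; with zero init
        # this reproduces the raw scores built up during fit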
        raw_predictions = np.zeros((X.shape[0], self.estimators_.shape[1]))
        for i in range(self.estimators_.shape[0]):
            for j in range(self.estimators_.shape[1]):
                raw_predictions[:, j] += self.learning_rate * self.estimators_[i, j].predict(X)
        return raw_predictions

    def decision_function(self, X):
        raw_predictions = self._predict(X)
        if self.n_classes_ == 2:
            return raw_predictions.ravel()
        else:
            return raw_predictions

    def predict_proba(self, X):
        scores = self.decision_function(X)
        if scores.ndim == 1:
            prob = expit(scores)
            prob = np.vstack((1 - prob, prob)).T
        else:
            prob = np.nan_to_num(np.exp(scores - logsumexp(scores, axis=1)[:, np.newaxis]))
        return prob

    def predict(self, X):
        scores = self.decision_function(X)
        if scores.ndim == 1:
            indices = (scores > 0).astype(int)
        else:
            indices = np.argmax(scores, axis=1)
        return self.classes_[indices]

    @property
    def feature_importances_(self):
        all_importances = np.zeros(self.n_features_)
        for i in range(self.estimators_.shape[0]):
            for j in range(self.estimators_.shape[1]):
                all_importances += self.estimators_[i, j].tree_.compute_feature_importances(normalize=False)
        return all_importances / np.sum(all_importances)


# In[3]:


# binary classification
X, y = load_breast_cancer(return_X_y=True)
clf1 = GradientBoostingClassifier(n_estimators=10).fit(X, y)
clf2 = skGradientBoostingClassifier(n_estimators=10, init="zero", presort=False, random_state=0).fit(X, y)
assert np.allclose(clf1.feature_importances_, clf2.feature_importances_)
prob1 = clf1.decision_function(X)
prob2 = clf2.decision_function(X)
assert np.allclose(prob1, prob2)
prob1 = clf1.predict_proba(X)
prob2 = clf2.predict_proba(X)
assert np.allclose(prob1, prob2)
pred1 = clf1.predict(X)
pred2 = clf2.predict(X)
assert np.array_equal(pred1, pred2)
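

# In[ ]:


# Per-stage check (a minimal sketch added for illustration): sklearn's
# staged_decision_function yields the raw scores after each boosting stage;
# rebuilding the same running sum from clf1.estimators_ should match stage
# by stage, since the asserts above show both models grow identical trees.
raw = np.zeros(X.shape[0])
for stage, staged in enumerate(clf2.staged_decision_function(X)):
    raw += clf1.learning_rate * clf1.estimators_[stage, 0].predict(X)
    assert np.allclose(raw, np.asarray(staged).ravel())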


# In[4]:


# multiclass classification
X, y = load_iris(return_X_y=True)
clf1 = GradientBoostingClassifier(n_estimators=10).fit(X, y)
clf2 = skGradientBoostingClassifier(n_estimators=10, init="zero", presort=False, random_state=0).fit(X, y)
assert np.allclose(clf1.feature_importances_, clf2.feature_importances_)
prob1 = clf1.decision_function(X)
prob2 = clf2.decision_function(X)
assert np.allclose(prob1, prob2)
prob1 = clf1.predict_proba(X)
prob2 = clf2.predict_proba(X)
assert np.allclose(prob1, prob2)
pred1 = clf1.predict(X)
pred2 = clf2.predict(X)
assert np.array_equal(pred1, pred2)
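

# In[ ]:


# Held-out check (a minimal sketch beyond the original asserts): because the
# two implementations grow identical trees, they should also agree on data
# they were not trained on.
from sklearn.model_selection import train_test_split
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf1 = GradientBoostingClassifier(n_estimators=10).fit(X_train, y_train)
clf2 = skGradientBoostingClassifier(n_estimators=10, init="zero", random_state=0).fit(X_train, y_train)
assert np.array_equal(clf1.predict(X_test), clf2.predict(X_test))
print("test accuracy:", np.mean(clf1.predict(X_test) == y_test))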