#!/usr/bin/env python
# coding: utf-8

# In[1]:


import numpy as np
from sklearn.datasets import load_iris
from sklearn.preprocessing import LabelBinarizer as skLabelBinarizer


# In[2]:


class LabelBinarizer():
    """Binarize labels in a one-vs-all fashion (minimal scikit-learn look-alike).

    Parameters
    ----------
    pos_label : scalar, default 1
        Value written in the column of the class a sample belongs to.
    neg_label : scalar, default 0
        Value written in every other position.

    Attributes
    ----------
    classes_ : ndarray
        Sorted unique labels seen by ``fit`` (as returned by ``np.unique``).
    """

    def __init__(self, pos_label=1, neg_label=0):
        self.pos_label = pos_label
        self.neg_label = neg_label

    def fit(self, y):
        """Record the sorted unique labels of ``y`` and return ``self``."""
        self.classes_ = np.unique(y)
        return self

    def transform(self, y):
        """Encode the 1-D label sequence ``y`` as an indicator matrix.

        Returns an array of shape ``(len(y), len(classes_))`` for three or
        more classes, and ``(len(y), 1)`` for one or two classes. For two
        classes, the single column marks membership in ``classes_[1]``.
        Labels that were not seen by ``fit`` produce rows of ``neg_label``
        only. The output dtype is the promotion of ``pos_label`` and
        ``neg_label``, so float labels are not truncated.
        """
        # Accept any array-like (lists, tuples), not just ndarrays.
        y = np.asarray(y)
        n_classes = len(self.classes_)
        if n_classes == 1:
            # scikit-learn encodes a single-class problem as all negatives;
            # mirror that so both implementations agree.
            mask = np.zeros((y.shape[0], 1), dtype=bool)
        else:
            # mask[i, j] is True when sample i has label classes_[j].
            mask = y[:, np.newaxis] == self.classes_[np.newaxis, :]
            if n_classes == 2:
                # Binary problems keep only the column of the second class.
                mask = mask[:, 1:]
        return np.where(mask, self.pos_label, self.neg_label)

    def fit_transform(self, y):
        """Fit on ``y`` and return its encoding (same as ``fit(y).transform(y)``)."""
        return self.fit(y).transform(y)


# In[3]:


# Load the iris Bunch once; the comparison cells below read its `target`
# and `target_names` arrays.
iris = load_iris(return_X_y=False)


# In[4]:


# Binary case: keep only targets 0 and 1, which both encoders should
# collapse to a single indicator column.
y = iris.target[iris.target != 2]
le1, le2 = LabelBinarizer().fit(y), skLabelBinarizer().fit(y)
assert np.array_equal(le1.classes_, le2.classes_)
yt1, yt2 = le1.transform(y), le2.transform(y)
assert yt1.shape == (len(y), 1)
assert np.array_equal(yt1, yt2)


# In[5]:


# Numeric multiclass: the full integer target must expand to one
# indicator column per class, identically in both encoders.
y = iris.target
le1, le2 = (Enc().fit(y) for Enc in (LabelBinarizer, skLabelBinarizer))
assert np.array_equal(le1.classes_, le2.classes_)
yt1, yt2 = (enc.transform(y) for enc in (le1, le2))
assert np.array_equal(yt1, yt2)


# In[6]:


# String multiclass: replace each integer target with its species name so
# both encoders are exercised on string labels.
y = np.take(iris.target_names, iris.target)
le1, le2 = LabelBinarizer().fit(y), skLabelBinarizer().fit(y)
assert np.array_equal(le1.classes_, le2.classes_)
yt1, yt2 = le1.transform(y), le2.transform(y)
assert np.array_equal(yt1, yt2)