import numpy as np
from scipy.stats.mstats import mode
from sklearn.datasets import load_iris
from sklearn.impute import SimpleImputer as skSimpleImputer
class SimpleImputer:
    """Minimal column-wise NaN imputer modelled on scikit-learn's ``SimpleImputer``.

    Parameters
    ----------
    strategy : {'mean', 'median', 'most_frequent', 'constant'}, default='mean'
        Statistic computed per column in ``fit`` and used to replace NaNs.
    fill_value : scalar or None, default=None
        Replacement value, only used when ``strategy == 'constant'``.
        ``None`` is treated as 0.

    Attributes
    ----------
    statistics_ : ndarray of shape (n_features,)
        Per-column replacement value, set by ``fit``.
    """

    def __init__(self, strategy='mean', fill_value=None):
        self.strategy = strategy
        self.fill_value = fill_value  # only used when strategy == 'constant'

    def fit(self, X):
        """Compute the per-column replacement values from ``X``.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Numeric data in which missing entries are encoded as NaN.

        Returns
        -------
        self : SimpleImputer
            The fitted imputer.

        Raises
        ------
        ValueError
            If ``self.strategy`` is not one of the supported strategies.
        """
        X = np.asarray(X)
        # Mask the NaNs so the reductions below only see observed values.
        masked_X = np.ma.masked_array(X, mask=np.isnan(X))
        if self.strategy == "mean":
            statistics = np.ma.mean(masked_X, axis=0)
        elif self.strategy == "median":
            statistics = np.ma.median(masked_X, axis=0)
        elif self.strategy == "most_frequent":
            statistics = mode(masked_X, axis=0)[0]
        elif self.strategy == "constant":
            # A None fill value would leave the missing entries unfilled;
            # fall back to 0 for numeric data instead.
            fill_value = 0 if self.fill_value is None else self.fill_value
            statistics = np.full(X.shape[1], fill_value)
        else:
            raise ValueError(
                "strategy must be one of 'mean', 'median', 'most_frequent' "
                "or 'constant', got {!r}".format(self.strategy)
            )
        # mstats.mode keeps the reduced axis; flatten so that statistics_ is
        # 1-D of length n_features whatever the strategy.
        self.statistics_ = np.array(statistics).ravel()
        return self

    def transform(self, X):
        """Return a copy of ``X`` with every NaN replaced by its column's statistic.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            Data to impute; it is not modified.

        Returns
        -------
        Xt : ndarray of shape (n_samples, n_features)
            Imputed copy of ``X``.
        """
        X = np.asarray(X)
        statistics = np.ravel(self.statistics_)
        Xt = X.copy()
        # Each missing cell (row, col) receives the statistic of column `col`.
        rows, cols = np.nonzero(np.isnan(X))
        Xt[rows, cols] = statistics[cols]
        return Xt
# Load the iris features and blank out one randomly chosen feature in every
# sample, so each row contains exactly one NaN.
X, _ = load_iris(return_X_y=True)
n_samples, n_features = X.shape
rng = np.random.RandomState(0)
missing_samples = np.arange(n_samples)
missing_features = rng.choice(n_features, n_samples)
X[missing_samples, missing_features] = np.nan

# For each configuration, check that the local imputer agrees with the
# scikit-learn one on both the fitted statistics and the transformed data.
for params in (
    {"strategy": "mean"},
    {"strategy": "median"},
    {"strategy": "most_frequent"},
    {"strategy": "constant", "fill_value": 0},
):
    est1 = SimpleImputer(**params).fit(X)
    est2 = skSimpleImputer(**params).fit(X)
    assert np.allclose(est1.statistics_, est2.statistics_)
    Xt1 = est1.transform(X)
    Xt2 = est2.transform(X)
    assert np.allclose(Xt1, Xt2)