#!/usr/bin/env python
# coding: utf-8
"""Minimal nearest-centroid classifier.

When run as a script, the classifier is checked against scikit-learn's
``NearestCentroid`` on the standardized iris dataset.
"""

# In[1]:

import numpy as np
from scipy.spatial.distance import cdist


# In[2]:

class NearestCentroid:
    """Classify samples by the closest per-class mean (centroid).

    Distances are computed with ``scipy.spatial.distance.cdist`` using its
    default (Euclidean) metric.

    Attributes set by ``fit``:
        classes_: sorted unique labels of ``y`` (as returned by ``np.unique``).
        centroids_: float array of shape (n_classes, n_features); row ``i``
            is the mean of the training samples labelled ``classes_[i]``.
    """

    def fit(self, X, y):
        """Compute one centroid per class.

        Args:
            X: array-like of shape (n_samples, n_features).
            y: array-like of shape (n_samples,) holding the class labels.

        Returns:
            self, so calls can be chained (``NearestCentroid().fit(X, y)``).
        """
        # Accept any array-like (e.g. nested lists), as scikit-learn estimators
        # do; the code below needs ndarray features (``.shape`` and boolean-mask
        # indexing), which plain Python sequences do not provide.
        X = np.asarray(X)
        y = np.asarray(y)
        self.classes_ = np.unique(y)
        self.centroids_ = np.zeros((len(self.classes_), X.shape[1]))
        for i, c in enumerate(self.classes_):
            self.centroids_[i] = np.mean(X[y == c], axis=0)
        return self

    def predict(self, X):
        """Return the label of the nearest centroid for each row of ``X``.

        Must be called after ``fit``. On distance ties, ``np.argmin`` picks
        the first centroid, i.e. the smallest label in ``classes_``.

        Args:
            X: array-like of shape (n_samples, n_features).

        Returns:
            ndarray of shape (n_samples,) with labels drawn from ``classes_``.
        """
        # cdist yields an (n_samples, n_classes) distance matrix; argmin over
        # axis 1 selects the closest centroid for every sample.
        distances = cdist(np.asarray(X), self.centroids_)
        return self.classes_[np.argmin(distances, axis=1)]


# In[3]:

if __name__ == "__main__":
    # Self-check against scikit-learn's reference implementation. scikit-learn
    # is imported here so the classifier above stays importable with only
    # numpy and scipy installed, and importing the module has no side effects.
    from sklearn.datasets import load_iris
    from sklearn.neighbors import NearestCentroid as skNearestCentroid
    from sklearn.preprocessing import StandardScaler

    X, y = load_iris(return_X_y=True)
    X = StandardScaler().fit_transform(X)
    clf1 = NearestCentroid().fit(X, y)
    clf2 = skNearestCentroid().fit(X, y)
    assert np.allclose(clf1.centroids_, clf2.centroids_)
    pred1 = clf1.predict(X)
    pred2 = clf2.predict(X)
    assert np.array_equal(pred1, pred2)