#!/usr/bin/env python
# coding: utf-8

# # K-Means Clustering in Python -- Synthetic Data
#
# This Jupyter notebook accompanies Real Python's [Practical K-Means Clustering in Python](https://realpython.com/practical-k-means-clustering-python/) article. Specifically, this notebook works through the **first half** of the article, which uses synthetic data.

# ## K-Means Clustering in Python

# In[1]:


import matplotlib.pyplot as plt
from kneed import KneeLocator
from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.preprocessing import StandardScaler


# In[2]:


features, true_labels = make_blobs(
    n_samples=200,
    centers=3,
    cluster_std=2.75,
    random_state=42
)


# In[3]:


features[:5]


# In[4]:


true_labels[:5]


# In[5]:


scaler = StandardScaler()
scaled_features = scaler.fit_transform(features)


# In[6]:


scaled_features[:5]


# In[7]:


kmeans = KMeans(
    init="random",
    n_clusters=3,
    n_init=10,
    max_iter=300,
    random_state=42
)


# In[8]:


kmeans.fit(scaled_features)


# In[9]:


# The lowest SSE value
kmeans.inertia_


# In[10]:


# Final locations of the centroids
kmeans.cluster_centers_


# In[11]:


# The number of iterations required to converge
kmeans.n_iter_


# In[12]:


kmeans.labels_[:5]


# ### How to Choose the Number of Clusters

# In[13]:


kmeans_kwargs = {
    "init": "random",
    "n_init": 10,
    "max_iter": 300,
    "random_state": 42,
}

# A list holds the SSE values for each k
sse = []
for k in range(1, 11):
    kmeans = KMeans(n_clusters=k, **kmeans_kwargs)
    kmeans.fit(scaled_features)
    sse.append(kmeans.inertia_)


# In[14]:


plt.style.use("fivethirtyeight")
plt.plot(range(1, 11), sse)
plt.xticks(range(1, 11))
plt.xlabel("Number of Clusters")
plt.ylabel("SSE")
plt.show()


# In[15]:


kl = KneeLocator(
    range(1, 11), sse, curve="convex", direction="decreasing"
)


# In[16]:


kl.elbow


# In[17]:


# A list holds the silhouette coefficients for each k
silhouette_coefficients = []

# Notice we start at 2 clusters for the silhouette coefficient
for k in range(2, 11):
    kmeans = KMeans(n_clusters=k, **kmeans_kwargs)
    kmeans.fit(scaled_features)
    score = silhouette_score(scaled_features, kmeans.labels_)
    silhouette_coefficients.append(score)


# In[18]:


plt.style.use("fivethirtyeight")
plt.plot(range(2, 11), silhouette_coefficients)
plt.xticks(range(2, 11))
plt.xlabel("Number of Clusters")
plt.ylabel("Silhouette Coefficient")
plt.show()
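
# A minimal extra sketch: pick the k that maximizes the silhouette coefficient computed above and refit k-means with it. `best_k` and `final_kmeans` are ad hoc names introduced here, not part of the article.

# In[ ]:


# Sketch: choose k programmatically from the silhouette results above.
# range(2, 11) mirrors the candidate values used in the loop.
best_k = range(2, 11)[
    silhouette_coefficients.index(max(silhouette_coefficients))
]
final_kmeans = KMeans(n_clusters=best_k, **kmeans_kwargs)
final_kmeans.fit(scaled_features)
best_k, final_kmeans.inertia_
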
# ## Advanced Clustering Evaluation

# In[19]:


from sklearn.cluster import DBSCAN
from sklearn.datasets import make_moons
from sklearn.metrics import adjusted_rand_score


# In[20]:


features, true_labels = make_moons(
    n_samples=250, noise=0.05, random_state=42
)
scaled_features = scaler.fit_transform(features)


# In[21]:


# Instantiate k-means and dbscan algorithms
kmeans = KMeans(n_clusters=2)
dbscan = DBSCAN(eps=0.3)

# Fit the algorithms to the features
kmeans.fit(scaled_features)
dbscan.fit(scaled_features)

# Compute the silhouette scores for each algorithm
kmeans_silhouette = silhouette_score(
    scaled_features, kmeans.labels_
).round(2)
dbscan_silhouette = silhouette_score(
    scaled_features, dbscan.labels_
).round(2)


# In[22]:


kmeans_silhouette


# In[23]:


dbscan_silhouette


# In[24]:


# Plot the data and cluster silhouette comparison
fig, (ax1, ax2) = plt.subplots(
    1, 2, figsize=(8, 6), sharex=True, sharey=True
)
fig.suptitle("Clustering Algorithm Comparison: Crescents", fontsize=16)
fte_colors = {
    0: "#008fd5",
    1: "#fc4f30",
}

# The k-means plot
km_colors = [fte_colors[label] for label in kmeans.labels_]
ax1.scatter(scaled_features[:, 0], scaled_features[:, 1], c=km_colors)
ax1.set_title(
    f"k-means\nSilhouette: {kmeans_silhouette}", fontdict={"fontsize": 12}
)

# The dbscan plot
db_colors = [fte_colors[label] for label in dbscan.labels_]
ax2.scatter(scaled_features[:, 0], scaled_features[:, 1], c=db_colors)
ax2.set_title(
    f"DBSCAN\nSilhouette: {dbscan_silhouette}", fontdict={"fontsize": 12}
)
plt.show()


# In[25]:


ari_kmeans = adjusted_rand_score(true_labels, kmeans.labels_)
ari_dbscan = adjusted_rand_score(true_labels, dbscan.labels_)


# In[26]:


round(ari_kmeans, 2)


# In[27]:


round(ari_dbscan, 2)
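
# A minimal extra check: unlike the silhouette coefficient, the adjusted Rand index (ARI) compares against the ground-truth labels and scores only the partition, not the numeric label assigned to each cluster, so relabeling the k-means clusters leaves it unchanged. `permuted_labels` is an ad hoc name introduced here.

# In[ ]:


# Sketch: ARI is invariant to how the clusters are numbered.
# With two clusters, 1 - labels swaps cluster 0 and cluster 1.
permuted_labels = 1 - kmeans.labels_
round(adjusted_rand_score(true_labels, permuted_labels), 2)  # same value as ari_kmeans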