#!/usr/bin/env python
# coding: utf-8

# # k-NN from scratch

# In[1]:


get_ipython().system('pip3 install plotly')


# In[2]:


import csv
from pathlib import Path
from math import sqrt
from operator import attrgetter
from collections import Counter
from typing import NamedTuple, List

from plotly import express as px
from plotly import graph_objects as go


# In[3]:


# Ensure that we have a `data` directory we use to store downloaded data.
# pathlib does this portably, without shelling out to `mkdir -p`.
data_dir: Path = Path('data')
data_dir.mkdir(parents=True, exist_ok=True)


# In[4]:


# Downloading the Iris data set (`-nc`: skip the download if the file already exists)
get_ipython().system('wget -nc -P data https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data')


# In[5]:


# The structure of the Iris data set is as follows:
# Sepal Length, Sepal Width, Petal Length, Petal Width, Class
get_ipython().system('tail data/iris.data')


# In[6]:


# Column indices into each parsed CSV row (and, for the first four, into the
# `measurements` list built from it below)
sepal_length_idx: int = 0
sepal_width_idx: int = 1
petal_length_idx: int = 2
petal_width_idx: int = 3
label_idx: int = 4


# In[7]:


# Create the Python path pointing to the `iris.data` file
iris_data_path: Path = data_dir / 'iris.data'


# In[8]:


# Our data container for individual Iris data set items
class IrisLabeled(NamedTuple):
    label: str                 # the class column, e.g. 'Iris-versicolor'
    measurements: List[float]  # [sepal length, sepal width, petal length, petal width]


# In[9]:


labeled_data: List[IrisLabeled] = []


# In[10]:


# Read the `iris.data` file and parse it row by row.
# The csv module documentation requires files to be opened with newline='';
# the explicit encoding avoids depending on the platform's locale default.
with open(iris_data_path, newline='', encoding='utf-8') as csv_file:
    reader = csv.reader(csv_file, delimiter=',')
    for row in reader:
        # Only accept complete datapoints (four measurements plus the label);
        # this skips e.g. empty lines, which csv.reader yields as [].
        if len(row) == label_idx + 1:
            label: str = row[label_idx]
            measurements: List[float] = [float(num) for num in row[:label_idx]]
            labeled_data.append(IrisLabeled(label, measurements))


# In[11]:


len(labeled_data)


# In[12]:


# Computing values for plotting
# The petal length
xs: List[float] = [iris.measurements[petal_length_idx] for iris in labeled_data]
# The petal width
ys: List[float] = [iris.measurements[petal_width_idx] for iris in labeled_data]
# Classes
text: List[str] = [iris.label for iris in labeled_data]


# In[13]:


fig = px.scatter(x=xs, y=ys, color=text, hover_name=text, labels={'x': 'Petal Length', 'y': 'Petal Width'})
fig.show()


# In[14]:


# Our made-up measurement we want to classify via k-NN
new_measurement: List[float] = [7, 3, 4.8, 1.5]


# In[15]:


# Re-plotting the Iris data with our new_measurement added to it
fig = px.scatter(x=xs, y=ys, color=text, hover_name=text, labels={'x': 'Petal Length', 'y': 'Petal Width'})
fig.add_annotation(
    go.layout.Annotation(
        x=new_measurement[petal_length_idx],
        y=new_measurement[petal_width_idx],
        text="The measurement we want to classify")
)
fig.update_annotations(dict(
    xref="x",
    yref="y",
    showarrow=True,
    arrowhead=7,
    ax=0,
    ay=-40,
    borderwidth=2,
    borderpad=4,
    bgcolor="#c3c3c3"
))
fig.show()


# In[16]:


def majority_vote(labels: List[str]) -> str:
    """Return the most frequent label in ``labels``.

    ``labels`` is expected to be ordered by relevance (e.g. neighbours sorted by
    distance from nearest to furthest). If several labels tie for the highest
    count, the last (least relevant) label is dropped and the vote is repeated
    until a single winner remains.

    Raises:
        ValueError: if ``labels`` is empty.
    """
    if not labels:
        raise ValueError('majority_vote() requires at least one label')

    candidates: List[str] = list(labels)  # private copy, so the caller's list is never mutated
    while True:
        counts = Counter(candidates)
        top_count = max(counts.values())
        winners = [label for label, count in counts.items() if count == top_count]
        if len(winners) == 1:
            return winners[0]
        # Tie: discard the least relevant label and vote again. This always
        # terminates, because a single remaining label is an unambiguous winner.
        candidates.pop()


assert majority_vote(['a', 'b', 'b', 'c']) == 'b'
assert majority_vote(['a', 'b', 'b', 'a']) == 'b'
assert majority_vote(['a', 'a', 'b', 'b', 'c']) == 'a'


# In[17]:


def distance(x: List[float], y: List[float]) -> float:
    """Return the Euclidean distance between two equal-length vectors.

    See: https://en.wikipedia.org/wiki/Euclidean_distance

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths.
    """
    # An explicit check (rather than ``assert``) survives ``python -O``; without it
    # ``zip`` would silently ignore the extra components of the longer vector.
    if len(x) != len(y):
        raise ValueError(f'vectors must have the same length, got {len(x)} and {len(y)}')
    return sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))


assert distance([1, 2, 3, 4], [5, 6, 7, 8]) == 8


# In[18]:


# The KNN implementation
def knn(labeled_data: List[IrisLabeled], new_measurement: List[float], k: int = 5) -> IrisLabeled:
    """Classify ``new_measurement`` by a majority vote of its ``k`` nearest neighbours.

    Args:
        labeled_data: the labelled points to compare against.
        new_measurement: the measurement vector to classify; must have the same
            length as every ``measurements`` list in ``labeled_data``.
        k: number of nearest neighbours that take part in the vote.

    Returns:
        An ``IrisLabeled`` pairing the winning label with ``new_measurement``.

    Raises:
        ValueError: if ``k`` is not positive or ``labeled_data`` is empty.
    """
    # Without this check a non-positive k slips through the slice below:
    # ``[:0]`` selects nothing and ``[:-1]`` silently selects all but the furthest point.
    if k < 1:
        raise ValueError(f'k must be a positive integer, got {k}')

    class Distance(NamedTuple):
        label: str
        distance: float

    # First, compute the distance from the new measurement to every labelled point
    # and order them from nearest to furthest. sorted() is stable, so equally
    # distant points keep their order from labeled_data.
    distances: List[Distance] = sorted(
        (Distance(data.label, distance(new_measurement, data.measurements)) for data in labeled_data),
        key=attrgetter('distance'),
    )

    # Second, the new measurement's label is the most common label among its k
    # nearest neighbours. They are passed nearest-first, which majority_vote
    # relies on to break ties.
    nearest_labels: List[str] = [neighbour.label for neighbour in distances[:k]]
    label: str = majority_vote(nearest_labels)
    return IrisLabeled(label, new_measurement)


assert knn(labeled_data, new_measurement, 5) == IrisLabeled('Iris-versicolor', [7, 3, 4.8, 1.5])


# In[19]:


knn(labeled_data, new_measurement, 5)