k-NN (k-nearest neighbors) from scratch

In [1]:
!pip3 install plotly
Requirement already satisfied: plotly in /usr/local/lib/python3.6/dist-packages (4.5.0)
Requirement already satisfied: retrying>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from plotly) (1.3.3)
Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from plotly) (1.14.0)
In [2]:
import csv
import urllib.request
from collections import Counter
from math import sqrt
from operator import attrgetter
from pathlib import Path
from typing import NamedTuple, List

from plotly import express as px
from plotly import graph_objects as go
In [3]:
# Ensure that we have a `data` directory we use to store downloaded data.
# Pure-Python replacement for `!mkdir -p data`: works without a POSIX shell, and the
# directory name is defined in exactly one place (`data_dir`).
data_dir: Path = Path('data')
data_dir.mkdir(parents=True, exist_ok=True)
In [4]:
# Downloading the Iris data set into `data_dir`.
# Like `wget -nc`, an existing file is never re-downloaded, so re-running this cell is cheap
# and does not hit the network. Pure Python, so no `wget` binary is required.
iris_url: str = 'https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data'
iris_download_target: Path = data_dir / 'iris.data'
if iris_download_target.exists():
    print(f'File {iris_download_target} already there; not retrieving.')
else:
    urllib.request.urlretrieve(iris_url, str(iris_download_target))
File ‘data/iris.data’ already there; not retrieving.

In [5]:
# The structure of the Iris data set is as follows:
# Sepal Length, Sepal Width, Petal Length, Petal Width, Class
!tail data/iris.data
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica

In [6]:
# Column positions within a raw row of `iris.data`, in file order.
# The first four also index `IrisLabeled.measurements`, which holds the numeric columns
# in the same order; `label_idx` only makes sense for the raw 5-field CSV row.
(
    sepal_length_idx,
    sepal_width_idx,
    petal_length_idx,
    petal_width_idx,
    label_idx,
) = range(5)
In [7]:
# Location of the downloaded Iris CSV file inside the data directory
iris_data_path: Path = data_dir.joinpath('iris.data')
In [8]:
# Immutable record for one Iris sample:
#   label        -- the class name taken from the last CSV column (e.g. 'Iris-virginica')
#   measurements -- the four numeric CSV columns as floats, in file order
IrisLabeled = NamedTuple('IrisLabeled', [
    ('label', str),
    ('measurements', List[float]),
])
In [9]:
# Container for the parsed Iris records; starts out empty and is populated by the CSV-parsing cell
labeled_data: List[IrisLabeled] = list()
In [10]:
# Read the `iris.data` file and parse every valid row into an `IrisLabeled` record.
# The list is rebuilt from scratch here (instead of appended to), so re-running this cell
# never duplicates rows. `newline=''` is how the `csv` module expects files to be opened.
with open(iris_data_path, newline='') as csv_file:
    labeled_data: List[IrisLabeled] = [
        IrisLabeled(label=row[-1], measurements=[float(num) for num in row[:-1]])
        for row in csv.reader(csv_file, delimiter=',')
        # Only rows with exactly 5 fields (4 measurements + class) are valid Iris datapoints;
        # anything else, e.g. an empty row from a blank line, is skipped
        if len(row) == 5
    ]
In [11]:
# Sanity check: the Iris data set has 150 samples, so any other count means parsing went wrong
n_samples: int = len(labeled_data)
assert n_samples == 150, f'Expected 150 Iris samples, got {n_samples}'
n_samples
Out[11]:
150
In [12]:
# Computing values for plotting: three parallel lists with one entry per sample,
# all in the same order as `labeled_data`

# x axis: petal length of every sample
xs: List[float] = [values[petal_length_idx] for values in map(attrgetter('measurements'), labeled_data)]
# y axis: petal width of every sample
ys: List[float] = [values[petal_width_idx] for values in map(attrgetter('measurements'), labeled_data)]
# Class label of every sample (used for colouring and as hover text)
text: List[str] = list(map(attrgetter('label'), labeled_data))
In [13]:
# Petal length vs. petal width for every sample, coloured by class.
# `labels` renames the auto-generated keys so axes, legend and hover text read
# 'Petal Length' / 'Petal Width' / 'Class' instead of 'x' / 'y' / 'color'.
fig = px.scatter(
    x=xs,
    y=ys,
    color=text,
    hover_name=text,
    labels={'x': 'Petal Length', 'y': 'Petal Width', 'color': 'Class'},
    title='Iris data set: petal length vs. petal width',
)
fig.show()