# Install Plotly, which provides the `plotly.express` / `plotly.graph_objects` APIs imported below.
# The leading `!` is IPython shell-escape syntax, so this line only runs inside a notebook.
!pip3 install plotly
Requirement already satisfied: plotly in /usr/local/lib/python3.6/dist-packages (4.5.0) Requirement already satisfied: retrying>=1.3.3 in /usr/local/lib/python3.6/dist-packages (from plotly) (1.3.3) Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from plotly) (1.14.0)
import csv
from pathlib import Path
from math import sqrt
from operator import attrgetter
from collections import Counter
from typing import NamedTuple, List
from plotly import express as px
from plotly import graph_objects as go
# Ensure that we have a `data` directory we use to store downloaded data
# (`-p` makes this succeed silently when the directory already exists)
!mkdir -p data
# Python-side handle to the same directory; must stay in sync with the `data` used by the shell commands
data_dir: Path = Path('data')
# Download the Iris data set into `data/`.
# `-nc` (no-clobber) skips the download when the file is already present, so re-running this cell is cheap.
!wget -nc -P data https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
File ‘data/iris.data’ already there; not retrieving.
# The structure of the Iris data set is as follows (one comma-separated record per line):
# Sepal Length, Sepal Width, Petal Length, Petal Width, Class
# Show the last rows of the file as a quick check of that layout
!tail data/iris.data
6.9,3.1,5.1,2.3,Iris-virginica 5.8,2.7,5.1,1.9,Iris-virginica 6.8,3.2,5.9,2.3,Iris-virginica 6.7,3.3,5.7,2.5,Iris-virginica 6.7,3.0,5.2,2.3,Iris-virginica 6.3,2.5,5.0,1.9,Iris-virginica 6.5,3.0,5.2,2.0,Iris-virginica 6.2,3.4,5.4,2.3,Iris-virginica 5.9,3.0,5.1,1.8,Iris-virginica
# Field positions within one parsed row of `iris.data`, in file order.
# Because the parser keeps the four numeric columns in that same order, the
# first four positions also index `IrisLabeled.measurements`.
sepal_length_idx: int
sepal_width_idx: int
petal_length_idx: int
petal_width_idx: int
label_idx: int
sepal_length_idx, sepal_width_idx, petal_length_idx, petal_width_idx, label_idx = range(5)
# Location of the downloaded Iris CSV file inside the data directory
iris_data_path: Path = data_dir.joinpath('iris.data')
# Our data container for individual Iris data set items
class IrisLabeled(NamedTuple):
    """One Iris sample: its class name together with its numeric measurements.

    ``measurements`` is expected to hold the four numeric columns in file order
    (sepal length, sepal width, petal length, petal width), so the ``*_idx``
    column constants can be used to index it.
    """

    # Class name, e.g. 'Iris-versicolor'
    label: str
    # Numeric feature values, in file column order
    measurements: List[float]
labeled_data: List[IrisLabeled] = []
# Read the `iris.data` file and parse it row by row.
# The csv module documentation asks for files to be opened with newline='' so the
# reader handles line endings itself; an explicit encoding avoids depending on the
# platform's default locale encoding.
with open(iris_data_path, newline='', encoding='utf-8') as csv_file:
    reader = csv.reader(csv_file, delimiter=',')
    for row in reader:
        # Keep only complete records (the numeric columns followed by the class);
        # anything else, e.g. an empty line (which csv.reader yields as []), is skipped.
        if len(row) == label_idx + 1:
            label: str = row[label_idx]
            measurements: List[float] = [float(num) for num in row[:label_idx]]
            labeled_data.append(IrisLabeled(label, measurements))
# Notebook cell value: the number of parsed samples
len(labeled_data)
150
# Computing values for plotting: petal length goes on the x axis, petal width on
# the y axis, and each sample's class is used both for colouring and as hover text.
# `xs`, `ys` and `text` are reused by the re-plot further down the notebook.
xs: List[float] = [sample.measurements[petal_length_idx] for sample in labeled_data]
ys: List[float] = [sample.measurements[petal_width_idx] for sample in labeled_data]
text: List[str] = [sample.label for sample in labeled_data]
fig = px.scatter(
    x=xs,
    y=ys,
    labels={'x': 'Petal Length', 'y': 'Petal Width'},
    color=text,
    hover_name=text,
)
fig.show()
# Our made up measurement we want to classify via KNN
# (same column order as the data set: sepal length, sepal width, petal length, petal width)
new_measurement: List[float] = [7, 3, 4.8, 1.5]
# Re-plot the labelled Iris data and mark where the new measurement falls
fig = px.scatter(
    x=xs,
    y=ys,
    labels={'x': 'Petal Length', 'y': 'Petal Width'},
    color=text,
    hover_name=text,
)
fig.add_annotation(
    go.layout.Annotation(
        # Anchor the arrow head on the new measurement's petal length / petal width,
        # expressed in the plot's data coordinates
        x=new_measurement[petal_length_idx],
        y=new_measurement[petal_width_idx],
        xref="x",
        yref="y",
        text="The measurement we want to classify",
        # Draw an arrow; ax / ay offset the text box from the arrow head
        showarrow=True,
        arrowhead=7,
        ax=0,
        ay=-40,
        # Grey label box with a border and padding
        borderwidth=2,
        borderpad=4,
        bgcolor="#c3c3c3",
    )
)
fig.show()
# Given a list of labels, what's the most used label in that list
# NOTE: The labels are expected to be sorted by relevance (e.g. by distance from
# nearest to furthest); ties are broken by discarding the last (least relevant)
# label and voting again.
def majority_vote(labels: List[str]) -> str:
    """Return the most frequent label in ``labels``.

    When several labels share the highest count, the last element is dropped and
    the vote is repeated, so ties are resolved in favour of labels that appear
    earlier. A single remaining label always wins, so the loop terminates for any
    non-empty input. This is done iteratively, so long tied inputs cannot hit
    the recursion limit.

    Raises:
        IndexError: if ``labels`` is empty (there is nothing to vote on).
    """
    candidates = labels
    while True:
        # Only the top two entries matter: they tell us whether the leader is unique
        ranking = Counter(candidates).most_common(2)
        top_label, top_count = ranking[0]  # IndexError on empty input, as before
        if len(ranking) == 1 or ranking[1][1] < top_count:
            return top_label
        # Tie for first place: drop the least relevant (last) label and re-vote
        candidates = candidates[:-1]
assert majority_vote(['a', 'b', 'b', 'c']) == 'b'
assert majority_vote(['a', 'b', 'b', 'a']) == 'b'
assert majority_vote(['a', 'a', 'b', 'b', 'c']) == 'a'
# Computes the Euclidean distance between two vectors
# See: https://en.wikipedia.org/wiki/Euclidean_distance
def distance(x: List[float], y: List[float]) -> float:
    """Return the Euclidean distance between the equal-length vectors ``x`` and ``y``.

    Raises:
        ValueError: if ``x`` and ``y`` have different lengths. This is checked
            explicitly because an ``assert`` is stripped under ``python -O``, and
            ``zip`` would otherwise silently ignore the extra components.
    """
    if len(x) != len(y):
        raise ValueError(f'distance() needs vectors of equal length, got {len(x)} and {len(y)}')
    return sqrt(sum((a - b) ** 2 for a, b in zip(x, y)))
assert distance([1, 2, 3, 4], [5, 6, 7, 8]) == 8
# The KNN implementation
def knn(labeled_data: List[IrisLabeled], new_measurement: List[float], k: int = 5) -> IrisLabeled:
    """Classify ``new_measurement`` by a majority vote of its ``k`` nearest neighbours.

    Args:
        labeled_data: the labelled samples to compare against.
        new_measurement: measurements in the same order as ``IrisLabeled.measurements``.
        k: number of nearest neighbours that vote; must be at least 1.

    Returns:
        A new ``IrisLabeled`` pairing the winning label with ``new_measurement``.

    Raises:
        ValueError: if ``k`` is less than 1.
        IndexError: (from ``majority_vote``) if ``labeled_data`` is empty.
    """
    # Reject k < 1 up front: a negative k would make the slice below drop the
    # farthest samples instead of keeping the nearest ones, and k == 0 leaves
    # nothing to vote on.
    if k < 1:
        raise ValueError(f'k must be at least 1, got {k}')
    # First, compute the distance between every labelled sample and the new measurement
    scored = [(sample.label, distance(new_measurement, sample.measurements))
              for sample in labeled_data]
    # Sort by distance only; the sort is stable, so equally distant samples keep their input order
    scored.sort(key=lambda pair: pair[1])
    # Second, the new measurement's label is the most used label among its k nearest
    # neighbours, passed nearest-first as majority_vote's tie-breaking expects
    nearest_labels = [label for label, _ in scored[:k]]
    return IrisLabeled(majority_vote(nearest_labels), new_measurement)
# Sanity check: the five nearest neighbours vote the made-up measurement into Iris-versicolor
assert IrisLabeled(label='Iris-versicolor', measurements=[7, 3, 4.8, 1.5]) == knn(labeled_data, new_measurement, k=5)
knn(labeled_data, new_measurement, k=5)
IrisLabeled(label='Iris-versicolor', measurements=[7, 3, 4.8, 1.5])