Decision Trees from scratch

In [1]:
import csv
from pathlib import Path
from copy import deepcopy
from typing import List, Tuple, Dict, NamedTuple, Any
from collections import Counter, defaultdict
In [2]:
# Ensure that we have a `data` directory we use to store downloaded data
# (created via `pathlib` rather than a shell command so the cell also works where `mkdir -p` doesn't exist)
data_dir: Path = Path('data')
data_dir.mkdir(parents=True, exist_ok=True)
In [3]:
# Downloading the "Golf" data set
!wget -O "data/golf.csv" -nc -P data
--2020-02-23 10:52:54--
Resolving (,,, ...
Connecting to (||:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 373 [text/plain]
Saving to: ‘data/golf.csv’

golf.csv            100%[===================>]     373  --.-KB/s    in 0s      

2020-02-23 10:52:55 (11.5 MB/s) - ‘data/golf.csv’ saved [373/373]

In [4]:
!head -n 5 data/golf.csv
In [5]:
# Location of the downloaded `golf.csv` file inside the data directory
golf_data_path: Path = data_dir.joinpath('golf.csv')
In [6]:
# Every entry in our data set is represented as a `DataPoint`
# (an immutable record whose fields are listed here in positional order)
DataPoint = NamedTuple('DataPoint', [
    ('outlook', str),
    ('temp', str),
    ('humidity', str),
    ('windy', bool),
    ('play', bool),
])
In [7]:
# Open the file, skip the header row, turn every remaining row into a `DataPoint` and append it to a list
data_points: List[DataPoint] = []

# `newline=''` lets the `csv` module deal with line endings itself
with open(golf_data_path, newline='') as csv_file:
    reader = csv.reader(csv_file, delimiter=',')
    # The first row only contains the column names
    next(reader, None)
    for row in reader:
        # Blank lines come back as empty rows; there is nothing to parse in them
        if not row:
            continue
        outlook: str = row[0].lower()
        temp: str = row[1].lower()
        humidity: str = row[2].lower()
        # The comparisons already yield booleans
        windy: bool = row[3].lower() == 't'
        play: bool = row[4].lower() == 'yes'
        # Collect the parsed row (without this the list would stay empty)
        data_points.append(DataPoint(outlook, temp, humidity, windy, play))
In [8]:
[DataPoint(outlook='rainy', temp='hot', humidity='high', windy=False, play=False),
 DataPoint(outlook='rainy', temp='hot', humidity='high', windy=True, play=False),
 DataPoint(outlook='overcast', temp='hot', humidity='high', windy=False, play=True),
 DataPoint(outlook='sunny', temp='mild', humidity='high', windy=False, play=True),
 DataPoint(outlook='sunny', temp='cool', humidity='normal', windy=False, play=True)]
In [9]:
# Gini impurity of a list of values: the sum of `p * (1 - p)` over the share `p` of every distinct value
def gini(data: List[Any]) -> float:
    """Return the Gini impurity of `data`.

    Every distinct value contributes `p * (1 - p)`, where `p` is the fraction
    of items in `data` that are equal to that value. A list containing a single
    distinct value (or no values at all) therefore has an impurity of 0.
    """
    size: int = len(data)
    # `Counter` tallies how often every distinct value occurs
    shares = (count / size for count in Counter(data).values())
    return sum(share * (1 - share) for share in shares)

assert gini(['one', 'one']) == 0
assert gini(['one', 'two']) == 0.5
assert gini(['one', 'two', 'one', 'two']) == 0.5
assert 0.8 < gini(['one', 'two', 'three', 'four', 'five']) < 0.81
In [10]:
# Helper function to filter down a list of data points by one or more `(feature, value)` pairs
def filter_by_feature(data_points: List[DataPoint], *args) -> List[DataPoint]:
    """Return the data points that match every given `(feature, value)` pair.

    Args:
        data_points: The data points to filter. The list itself is never modified.
        *args: Any number of `(feature, value)` pairs. A data point is kept only if
            `getattr(data_point, feature) == value` holds for all of them.

    Returns:
        A new list with the matching data points in their original order. Without
        any pairs every data point matches.
    """
    # The points are only read, never modified, so the matches are returned as they are;
    # deep-copying the whole input on every call would be pure overhead
    return [
        data_point
        for data_point in data_points
        if all(getattr(data_point, arg[0]) == arg[1] for arg in args)
    ]

# Sanity checks: every additional `(feature, value)` pair narrows the selection further
for _filters, _expected in (
    ([('outlook', 'sunny')], 5),
    ([('outlook', 'sunny'), ('temp', 'mild')], 3),
    ([('outlook', 'sunny'), ('temp', 'mild'), ('humidity', 'high')], 2),
):
    assert len(filter_by_feature(data_points, *_filters)) == _expected
In [11]:
# Helper function to extract the distinct values the `feature` in question can assume
def feature_values(data_points: List[DataPoint], feature: str) -> List[Any]:
    """Return the distinct values of `feature` found in `data_points`.

    Args:
        data_points: The data points to inspect.
        feature: Name of the attribute to read from every data point.

    Returns:
        Every distinct value exactly once, ordered by first appearance.
    """
    # `dict.fromkeys` removes duplicates while keeping the first-seen order. Going through a
    # `set` instead would order string values differently from one interpreter run to the
    # next (string hashes are randomized), which makes everything built on top of this
    # helper, such as the order of the edges in the tree, irreproducible.
    return list(dict.fromkeys(getattr(dp, feature) for dp in data_points))

# `list.sort()` sorts in place and returns `None`, so comparing its results would always
# succeed; compare sorted copies instead to actually check the extracted values
assert sorted(feature_values(data_points, 'outlook')) == sorted(['sunny', 'overcast', 'rainy'])
In [12]:
# Weighted sum of the Gini impurities we'd get by splitting the data points on the `feature` in question
def gini_for_feature(data_points: List[DataPoint], feature: str, label: str = 'play') -> float:
    """Return the weighted Gini impurity of `label` after splitting on `feature`.

    The data points are grouped by the value they have for `feature`. The Gini
    impurity of the `label` values of every group is weighted by the fraction of
    all data points that fall into that group, and the weighted impurities are
    summed up.
    """
    total: int = len(data_points)
    weighted_parts: List[float] = []
    # One group per distinct value the `feature` assumes
    for value in feature_values(data_points, feature):
        subset: List[DataPoint] = filter_by_feature(data_points, (feature, value))
        subset_labels: List[Any] = [getattr(dp, label) for dp in subset]
        # Share of all data points in this group times the impurity of the group
        weighted_parts.append((len(subset_labels) / total) * gini(subset_labels))
    return sum(weighted_parts)

assert 0.34 < gini_for_feature(data_points, 'outlook') < 0.35
assert 0.44 < gini_for_feature(data_points, 'temp') < 0.45
assert 0.36 < gini_for_feature(data_points, 'humidity') < 0.37
assert 0.42 < gini_for_feature(data_points, 'windy') < 0.43
In [13]:
# NOTE: `Node` and `Edge` refer to each other, so type hints would have to be written as string forward references; they are left out here

# A `Node` has a `value` and optional out `Edge`s
class Node:
    """A node of the tree: a value plus the edges leaving it.

    A node without edges is a leaf.
    """

    def __init__(self, value):
        self._value = value
        self._edges = []

    def __repr__(self):
        # Inner nodes also show their outgoing edges, leaves just their value
        if self._edges:
            return f'{self._value} --> {self._edges}'
        return f'{self._value}'

    @property
    def value(self):
        """The value stored in this node."""
        return self._value

    def add_edge(self, edge):
        """Attach `edge` as an additional outgoing edge of this node."""
        self._edges.append(edge)

    def find_edge(self, value):
        """Return the first outgoing edge whose `value` equals `value`, or `None` if there is none."""
        # The `None` default keeps a missing edge from surfacing as a bare `StopIteration`
        return next((edge for edge in self._edges if edge.value == value), None)

# An `Edge` has a value and points to a `Node`
class Edge:
    """An edge of the tree: a value plus the node it points to.

    The target node is `None` until it is assigned via the `node` property.
    """

    def __init__(self, value):
        self._value = value
        self._node = None

    def __repr__(self):
        return f'{self._value} --> {self._node}'

    @property
    def value(self):
        """The value stored on this edge."""
        return self._value

    @property
    def node(self):
        """The node this edge points to (`None` while unset)."""
        return self._node

    @node.setter
    def node(self, node):
        self._node = node
In [14]:
# Recursively build a tree via the CART algorithm based on our list of data points
def build_tree(data_points: List[DataPoint], features: List[str], label: str = 'play') -> Node:
    """Recursively build a decision tree, always splitting on the feature with the lowest weighted Gini impurity.

    Args:
        data_points: The data points the (sub)tree is built from. Must not be empty.
        features: Names of the attributes that may be used for splitting. `label` may be
            part of the list; it is ignored. The list is not modified.
        label: Name of the attribute that holds the value to predict.

    Returns:
        The root `Node` of the (sub)tree. Inner nodes hold the name of the feature they
        split on and have one `Edge` per value of that feature; leaves hold a label value.

    Raises:
        ValueError: If `data_points` is empty.
    """
    if not data_points:
        raise ValueError('Cannot build a tree without data points')

    # Work on a copy that doesn't include the `label` so the caller's list stays untouched
    candidates: List[str] = [feature for feature in features if feature != label]

    labels: List[Any] = [getattr(dp, label) for dp in data_points]
    # Create a final `Node` (leaf) once there's nothing left to decide: either every data point
    # carries the same label, or there's no feature left to split on. In the latter case the
    # most common label wins.
    if len(set(labels)) == 1 or not candidates:
        return Node(Counter(labels).most_common(1)[0][0])

    # Compute the weighted Gini impurity for each `feature` given that we'd split the tree at the `feature` in question
    weighted_sums: Dict[str, float] = {
        feature: gini_for_feature(data_points, feature, label) for feature in candidates
    }

    # The feature with the lowest weighted Gini impurity is the one we should use for splitting
    min_feature: str = min(weighted_sums, key=weighted_sums.get)
    node: Node = Node(min_feature)
    # Remove the `feature` we've processed from the list of `features` which still need to be processed
    reduced_features: List[str] = [feature for feature in candidates if feature != min_feature]

    # Next up we build the `Edge`s which are the values our `min_feature` can assume
    for value in feature_values(data_points, min_feature):
        # Create a new `Edge` which contains a potential `value` of our `min_feature`
        edge: Edge = Edge(value)
        # Add the `Edge` to our `Node`
        node.add_edge(edge)
        # Only the data points that have this `value` for our `min_feature` are relevant for the subtree
        reduced_data_points: List[DataPoint] = filter_by_feature(data_points, (min_feature, value))
        # This `Edge` points to the new `Node` (subtree) we'll create through recursion
        edge.node = build_tree(reduced_data_points, reduced_features, label)

    # Return the `Node` (our `min_feature`)
    return node
In [15]:
# Every field of a `DataPoint` is passed along as a feature (`build_tree` excludes the label itself)
features: List[str] = [*DataPoint._fields]

# Build the tree from all the loaded data points
tree: Node = build_tree(data_points, features)
outlook --> [overcast --> True, sunny --> windy --> [False --> True, True --> False], rainy --> humidity --> [normal --> True, high --> False]]
In [16]:
# Traverse the tree based on the query trying to find a leaf with the prediction
def predict(tree: Node, query: List[Tuple[str, str]]) -> Any:
    """Walk down `tree` along the `(feature, value)` pairs in `query`.

    The pairs are processed in order. A pair whose feature is not the one stored in
    the current node is skipped; otherwise the edge with the pair's value is followed.

    Args:
        tree: The root node to start from. The tree is only read, never modified.
        query: `(feature, value)` pairs describing the item to classify.

    Returns:
        The node reached after the last pair was processed; its `value` is the
        prediction if that node is a leaf.

    Raises:
        Exception: If the current node has no edge for the value given in the query.
    """
    # Walking the tree doesn't change it, so there's no need to work on a copy
    node: Node = tree
    for item in query:
        feature: str = item[0]
        value: Any = item[1]
        # This part of the query is about a feature the current node doesn't split on
        if node.value != feature:
            continue
        edge: Edge = node.find_edge(value)
        if not edge:
            raise Exception(f'Edge with value "{value}" not found on Node "{node}"')
        node = edge.node
    return node

# `predict` returns the `Node` it ends up at, so the prediction has to be read from the node's `value`
# (comparing the `Node` itself against a boolean via `!=` always succeeds and checks nothing)
assert predict(tree, [('outlook', 'overcast')]).value is True
assert predict(tree, [('outlook', 'sunny'), ('windy', False)]).value is True
assert predict(tree, [('outlook', 'sunny'), ('windy', True)]).value is False
assert predict(tree, [('outlook', 'rainy'), ('humidity', 'high')]).value is False
assert predict(tree, [('outlook', 'rainy'), ('humidity', 'normal')]).value is True
# Parts of the query which don't matter on the chosen path (`windy` here) are skipped
assert predict(tree, [('outlook', 'rainy'), ('windy', True), ('humidity', 'normal')]).value is True
In [17]:
predict(tree, [('outlook', 'rainy'), ('humidity', 'normal')])