This is a walkthrough of some code for a simple decision-tree learner.
We'll import a dataset in which each row is a mushroom and each column specifies an attribute of that mushroom. We want to predict whether a mushroom is poisonous based on its other attributes.
NOTE: This is not a robust version of the algorithm: for example, it can only handle discrete/nominal attributes, it doesn't handle overfitting, etc. The purpose is for the code to be relatively simple and easy to understand.
import pandas as pd

# read_csv with index_col=0 reproduces the behavior of the old
# DataFrame.from_csv, which has since been removed from pandas
df_shroom = pd.read_csv('../datasets/mushroom_data.csv', index_col=0)
df_shroom
class | cap-shape | cap-surface | cap-color | bruises? | odor | gill-attachment | gill-spacing | gill-size | gill-color | ... | stalk-surface-below-ring | stalk-color-above-ring | stalk-color-below-ring | veil-type | veil-color | ring-number | ring-type | spore-print-color | population | habitat | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | poisonous | convex | smooth | brown | bruises | pungent | free | close | narrow | black | ... | smooth | white | white | partial | white | one | pendant | black | scattered | urban |
1 | edible | convex | smooth | yellow | bruises | almond | free | close | broad | black | ... | smooth | white | white | partial | white | one | pendant | brown | numerous | grasses |
2 | edible | bell | smooth | white | bruises | anise | free | close | broad | brown | ... | smooth | white | white | partial | white | one | pendant | brown | numerous | meadows |
3 | poisonous | convex | scaly | white | bruises | pungent | free | close | narrow | brown | ... | smooth | white | white | partial | white | one | pendant | black | scattered | urban |
4 | edible | convex | smooth | gray | no | none | free | crowded | broad | black | ... | smooth | white | white | partial | white | one | evanescent | brown | abundant | grasses |
5 | edible | convex | scaly | yellow | bruises | almond | free | close | broad | brown | ... | smooth | white | white | partial | white | one | pendant | black | numerous | grasses |
6 | edible | bell | smooth | white | bruises | almond | free | close | broad | gray | ... | smooth | white | white | partial | white | one | pendant | black | numerous | meadows |
7 | edible | bell | scaly | white | bruises | anise | free | close | broad | brown | ... | smooth | white | white | partial | white | one | pendant | brown | scattered | meadows |
8 | poisonous | convex | scaly | white | bruises | pungent | free | close | narrow | pink | ... | smooth | white | white | partial | white | one | pendant | black | several | grasses |
9 | edible | bell | smooth | yellow | bruises | almond | free | close | broad | gray | ... | smooth | white | white | partial | white | one | pendant | black | scattered | meadows |
10 | edible | convex | scaly | yellow | bruises | anise | free | close | broad | gray | ... | smooth | white | white | partial | white | one | pendant | brown | numerous | grasses |
11 | edible | convex | scaly | yellow | bruises | almond | free | close | broad | brown | ... | smooth | white | white | partial | white | one | pendant | black | scattered | meadows |
12 | edible | bell | smooth | yellow | bruises | almond | free | close | broad | white | ... | smooth | white | white | partial | white | one | pendant | brown | scattered | grasses |
13 | poisonous | convex | scaly | white | bruises | pungent | free | close | narrow | black | ... | smooth | white | white | partial | white | one | pendant | brown | several | urban |
14 | edible | convex | fibrous | brown | no | none | free | crowded | broad | brown | ... | fibrous | white | white | partial | white | one | evanescent | black | abundant | grasses |
15 | edible | sunken | fibrous | gray | no | none | free | close | narrow | black | ... | smooth | white | white | partial | white | one | pendant | brown | solitary | urban |
16 | edible | flat | fibrous | white | no | none | free | crowded | broad | black | ... | smooth | white | white | partial | white | one | evanescent | brown | abundant | grasses |
17 | poisonous | convex | smooth | brown | bruises | pungent | free | close | narrow | brown | ... | smooth | white | white | partial | white | one | pendant | black | scattered | grasses |
18 | poisonous | convex | scaly | white | bruises | pungent | free | close | narrow | brown | ... | smooth | white | white | partial | white | one | pendant | brown | scattered | urban |
19 | poisonous | convex | smooth | brown | bruises | pungent | free | close | narrow | black | ... | smooth | white | white | partial | white | one | pendant | brown | scattered | urban |
20 | edible | bell | smooth | yellow | bruises | almond | free | close | broad | black | ... | smooth | white | white | partial | white | one | pendant | brown | scattered | meadows |
21 | poisonous | convex | scaly | brown | bruises | pungent | free | close | narrow | brown | ... | smooth | white | white | partial | white | one | pendant | brown | several | grasses |
22 | edible | bell | scaly | yellow | bruises | anise | free | close | broad | black | ... | smooth | white | white | partial | white | one | pendant | brown | scattered | meadows |
23 | edible | bell | scaly | white | bruises | almond | free | close | broad | white | ... | smooth | white | white | partial | white | one | pendant | brown | numerous | meadows |
24 | edible | bell | smooth | white | bruises | anise | free | close | broad | gray | ... | smooth | white | white | partial | white | one | pendant | black | scattered | meadows |
25 | poisonous | flat | smooth | white | bruises | pungent | free | close | narrow | brown | ... | smooth | white | white | partial | white | one | pendant | brown | several | grasses |
26 | edible | convex | scaly | yellow | bruises | almond | free | close | broad | brown | ... | smooth | white | white | partial | white | one | pendant | brown | numerous | meadows |
27 | edible | convex | scaly | white | bruises | anise | free | close | broad | white | ... | smooth | white | white | partial | white | one | pendant | brown | numerous | meadows |
28 | edible | flat | fibrous | brown | no | none | free | close | narrow | black | ... | smooth | white | white | partial | white | one | pendant | black | solitary | urban |
29 | edible | convex | smooth | yellow | bruises | almond | free | crowded | narrow | brown | ... | smooth | white | white | partial | white | one | pendant | brown | several | woods |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
5614 | edible | flat | smooth | gray | bruises | none | free | close | broad | white | ... | smooth | white | white | partial | white | two | pendant | white | solitary | paths |
5615 | edible | convex | smooth | pink | bruises | none | free | close | broad | white | ... | smooth | white | white | partial | white | two | pendant | white | several | paths |
5616 | poisonous | conical | scaly | yellow | no | none | free | crowded | narrow | white | ... | scaly | yellow | yellow | partial | yellow | one | evanescent | white | clustered | leaves |
5617 | poisonous | knobbed | scaly | red | no | musty | attached | close | broad | yellow | ... | scaly | cinnamon | cinnamon | partial | white | none | none | white | clustered | woods |
5618 | poisonous | convex | scaly | cinnamon | no | musty | free | close | broad | white | ... | scaly | cinnamon | cinnamon | partial | white | none | none | white | clustered | woods |
5619 | edible | bell | scaly | brown | no | none | free | close | broad | white | ... | scaly | brown | brown | partial | white | two | pendant | white | solitary | woods |
5620 | edible | convex | scaly | brown | bruises | none | free | close | broad | white | ... | smooth | white | white | partial | white | two | pendant | white | several | paths |
5621 | poisonous | convex | scaly | brown | no | musty | attached | close | broad | yellow | ... | scaly | cinnamon | cinnamon | partial | white | none | none | white | clustered | woods |
5622 | poisonous | knobbed | scaly | yellow | no | none | free | crowded | narrow | white | ... | scaly | yellow | yellow | partial | yellow | one | evanescent | white | clustered | leaves |
5623 | edible | knobbed | smooth | brown | no | none | free | close | broad | white | ... | scaly | brown | brown | partial | white | two | pendant | white | solitary | woods |
5624 | poisonous | flat | scaly | cinnamon | no | musty | attached | close | broad | white | ... | scaly | cinnamon | cinnamon | partial | white | none | none | white | clustered | woods |
5625 | poisonous | convex | scaly | brown | no | musty | free | close | broad | white | ... | scaly | cinnamon | cinnamon | partial | white | none | none | white | clustered | woods |
5626 | poisonous | knobbed | scaly | red | no | musty | free | close | broad | white | ... | scaly | cinnamon | cinnamon | partial | white | none | none | white | clustered | woods |
5627 | edible | convex | scaly | brown | no | none | free | close | broad | white | ... | scaly | brown | brown | partial | white | two | pendant | white | solitary | woods |
5628 | edible | flat | scaly | gray | bruises | none | free | close | broad | white | ... | smooth | white | white | partial | white | two | pendant | white | several | paths |
5629 | poisonous | flat | scaly | brown | no | musty | attached | close | broad | yellow | ... | scaly | cinnamon | cinnamon | partial | white | none | none | white | clustered | woods |
5630 | edible | flat | scaly | pink | bruises | none | free | close | broad | white | ... | smooth | white | white | partial | white | two | pendant | white | several | paths |
5631 | edible | flat | smooth | brown | no | none | free | close | broad | white | ... | scaly | brown | brown | partial | white | two | pendant | white | solitary | paths |
5632 | edible | flat | scaly | gray | bruises | none | free | close | broad | white | ... | smooth | white | white | partial | white | two | pendant | white | solitary | paths |
5633 | poisonous | knobbed | scaly | red | no | musty | attached | close | broad | white | ... | scaly | cinnamon | cinnamon | partial | white | none | none | white | clustered | woods |
5634 | edible | convex | scaly | cinnamon | bruises | none | free | close | broad | white | ... | smooth | white | white | partial | white | two | pendant | white | several | paths |
5635 | edible | flat | smooth | cinnamon | bruises | none | free | close | broad | white | ... | smooth | white | white | partial | white | two | pendant | white | several | paths |
5636 | edible | convex | scaly | brown | bruises | none | free | close | broad | white | ... | smooth | white | white | partial | white | two | pendant | white | solitary | paths |
5637 | poisonous | knobbed | scaly | cinnamon | no | musty | attached | close | broad | yellow | ... | scaly | cinnamon | cinnamon | partial | white | none | none | white | clustered | woods |
5638 | edible | flat | smooth | brown | no | none | free | close | broad | white | ... | scaly | brown | brown | partial | white | two | pendant | white | solitary | woods |
5639 | edible | bell | scaly | brown | no | none | free | close | broad | white | ... | scaly | brown | brown | partial | white | two | pendant | white | solitary | paths |
5640 | edible | convex | scaly | brown | no | none | free | close | broad | white | ... | scaly | brown | brown | partial | white | two | pendant | white | solitary | paths |
5641 | edible | convex | scaly | gray | bruises | none | free | close | broad | white | ... | smooth | white | white | partial | white | two | pendant | white | solitary | paths |
5642 | poisonous | convex | scaly | cinnamon | no | musty | free | close | broad | yellow | ... | scaly | cinnamon | cinnamon | partial | white | none | none | white | clustered | woods |
5643 | poisonous | flat | scaly | cinnamon | no | musty | attached | close | broad | yellow | ... | scaly | cinnamon | cinnamon | partial | white | none | none | white | clustered | woods |
5644 rows × 23 columns
The decision tree algorithm works by looking at all the attributes, and picking the 'best' one to split the data on, then running recursively on the split data.
The way we pick the 'best' attribute to split on is by picking the attribute that decreases 'entropy' the most.
What's entropy? Recall that we're trying to classify mushrooms as poisonous or edible. Entropy is a value that is low when a group of mushrooms mostly shares the same class (all poisonous or all edible) and high when a group of mushrooms is mixed (half poisonous, half edible). So a split of the data that minimizes entropy is a split that does a good job of separating the classes.
import math
from collections import Counter

def entropy(probs):
    '''
    Takes a list of probabilities and calculates their entropy.
    '''
    return sum(-prob * math.log(prob, 2) for prob in probs)

def entropy_of_list(a_list):
    '''
    Takes a list of items with discrete values (e.g., poisonous, edible)
    and returns the entropy for those items.
    '''
    # Tally up each value:
    cnt = Counter(a_list)
    # Convert counts to proportions:
    num_instances = len(a_list) * 1.0
    probs = [count / num_instances for count in cnt.values()]
    # Calculate entropy of the proportions:
    return entropy(probs)
# The initial entropy of the poisonous/edible class for our dataset:
total_entropy = entropy_of_list(df_shroom['class'])
print(total_entropy)
0.959441337353
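To see the behavior described above, here is a quick sanity check on two toy lists (hypothetical values, not from the dataset): a homogeneous group has entropy 0, and a 50/50 split has the maximum entropy of 1 bit. The helper is the same `entropy_of_list` defined above, inlined so the snippet stands alone:

```python
import math
from collections import Counter

def entropy_of_list(a_list):
    # Same computation as the entropy_of_list helper defined above
    cnt = Counter(a_list)
    n = float(len(a_list))
    return sum(-count / n * math.log(count / n, 2) for count in cnt.values())

print(entropy_of_list(['edible'] * 10))                     # homogeneous -> 0.0
print(entropy_of_list(['edible'] * 5 + ['poisonous'] * 5))  # 50/50 split -> 1.0
```

The dataset's entropy of ~0.96 tells us it is close to an even split between the two classes.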
To decide which attribute to split on, we want to quantify how much each attribute decreases the entropy.
We do this in a fairly intuitive way: we split our dataset by the possible values of an attribute, then do a weighted sum of the entropies for each of these split datasets, weighted by how big that sub-dataset is.
We'll make a function that quantifies the decrease in entropy, or conversely, the gain in information.
def information_gain(df, split_attribute_name, target_attribute_name, trace=0):
    '''
    Takes a DataFrame of attributes and quantifies the entropy of a target
    attribute after performing a split along the values of another attribute.
    '''
    # Split data by possible values of the attribute:
    df_split = df.groupby(split_attribute_name)

    # Calculate entropy of the target attribute, as well as the
    # proportion of observations in each data split:
    nobs = len(df.index) * 1.0
    df_agg_ent = df_split.agg({target_attribute_name: [entropy_of_list, lambda x: len(x) / nobs]})[target_attribute_name]
    df_agg_ent.columns = ['Entropy', 'PropObservations']
    if trace:  # helps to understand what this function is doing
        print(df_agg_ent)

    # Information gain = old entropy minus the weighted sum of post-split entropies:
    new_entropy = sum(df_agg_ent['Entropy'] * df_agg_ent['PropObservations'])
    old_entropy = entropy_of_list(df[target_attribute_name])
    return old_entropy - new_entropy
print('\nExample: Info-gain for best attribute is ' + str(information_gain(df_shroom, 'odor', 'class')))
Example: Info-gain for best attribute is 0.859670435885
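The weighted-sum arithmetic can be checked by hand. As a sketch, take a hypothetical group of 4 mushrooms (2 poisonous, 2 edible) where splitting on odor puts 1 edible mushroom in a 'none' branch and the other 3 in a 'pungent' branch:

```python
import math

# Entropy before the split: 2 poisonous / 2 edible -> exactly 1 bit
old_entropy = 1.0

# 'none' branch: 1 edible only -> entropy 0
# 'pungent' branch: 2 poisonous + 1 edible:
ent_pungent = -(2 / 3.0) * math.log(2 / 3.0, 2) - (1 / 3.0) * math.log(1 / 3.0, 2)

# Weight each branch's entropy by its share of the observations (1/4 and 3/4):
new_entropy = (1 / 4.0) * 0.0 + (3 / 4.0) * ent_pungent

print(old_entropy - new_entropy)  # information gain of splitting on odor, about 0.31
```

This is exactly what `information_gain` computes, with `groupby` and `agg` doing the per-branch bookkeeping.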
Now we'll write the decision tree algorithm itself, which is called "ID3".
def id3(df, target_attribute_name, attribute_names, default_class=None):
    ## Tally target attribute:
    from collections import Counter
    cnt = Counter(x for x in df[target_attribute_name])

    ## First check: Is this split of the dataset homogeneous?
    # (e.g., all mushrooms in this set are poisonous)
    # If yes, return that homogeneous label (e.g., 'poisonous')
    if len(cnt) == 1:
        return next(iter(cnt))

    ## Second check: Is this split of the dataset empty, or are we out of attributes?
    # If yes, return a default value
    elif df.empty or (not attribute_names):
        return default_class

    ## Otherwise: This dataset is ready to be divvied up!
    else:
        # Get default value for the next recursive call of this function:
        default_class = cnt.most_common(1)[0][0]  # most common value of target attribute in dataset

        # Choose best attribute to split on:
        gains = [information_gain(df, attr, target_attribute_name) for attr in attribute_names]
        best_attr = attribute_names[gains.index(max(gains))]

        # Create an empty tree, to be populated in a moment
        tree = {best_attr: {}}
        remaining_attribute_names = [i for i in attribute_names if i != best_attr]

        # Split the dataset on the best attribute; on each split, recursively
        # call this algorithm, and populate the empty tree with the subtrees
        # that the recursive calls return
        for attr_val, data_subset in df.groupby(best_attr):
            subtree = id3(data_subset,
                          target_attribute_name,
                          remaining_attribute_names,
                          default_class)
            tree[best_attr][attr_val] = subtree
        return tree
# Get Predictor Names (all but 'class')
attribute_names = list(df_shroom.columns)
attribute_names.remove('class')
# Run Algorithm:
from pprint import pprint
tree = id3(df_shroom, 'class', attribute_names)
pprint(tree)
{'odor':
  {'almond': 'edible',
   'anise': 'edible',
   'creosote': 'poisonous',
   'foul': 'poisonous',
   'musty': 'poisonous',
   'none':
     {'spore-print-color':
       {'black': 'edible',
        'brown': 'edible',
        'green': 'poisonous',
        'white':
          {'ring-type':
            {'evanescent':
              {'stalk-surface-above-ring':
                {'fibrous': 'edible', 'scaly': 'poisonous', 'smooth': 'edible'}},
             'pendant':
              {'stalk-surface-above-ring':
                {'scaly': 'edible',
                 'smooth':
                   {'stalk-surface-below-ring':
                     {'smooth':
                       {'stalk-color-above-ring':
                         {'white':
                           {'stalk-color-below-ring':
                             {'white':
                               {'stalk-shape':
                                 {'enlarging':
                                   {'gill-color':
                                     {'white':
                                       {'cap-color':
                                         {'brown': 'edible',
                                          'cinnamon': 'edible',
                                          'gray': 'edible',
                                          'pink': 'edible',
                                          'white': 'poisonous'}}}}}}}}}}}}}}}}}},
   'pungent': 'poisonous'}}
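The tree is just nested dicts: each outer key names an attribute, and each value under it is either a class label (a leaf) or another subtree. As a sketch of how to walk this structure, here's a leaf counter run on a small hypothetical tree of the same shape:

```python
def count_leaves(tree):
    # A leaf is a plain label (a string); anything else is a subtree dict.
    if not isinstance(tree, dict):
        return 1
    # tree maps one attribute name to a dict of {attribute value: subtree-or-label}
    return sum(count_leaves(subtree)
               for branches in tree.values()
               for subtree in branches.values())

# Hypothetical miniature tree, just for illustration:
toy_tree = {'odor': {'pungent': 'poisonous',
                     'none': {'cap-color': {'white': 'poisonous',
                                            'brown': 'edible'}}}}
print(count_leaves(toy_tree))  # -> 3
```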
Let's make sure our resulting tree accurately predicts the class, based on the features.
Below is a 'classify' function that takes an instance and classifies it by walking the tree.
def classify(instance, tree, default=None):
    attribute = next(iter(tree))  # root attribute of this (sub)tree
    if instance[attribute] in tree[attribute]:
        result = tree[attribute][instance[attribute]]
        if isinstance(result, dict):  # this is a subtree: delve deeper
            return classify(instance, result)
        else:
            return result  # this is a label
    else:
        return default
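As a quick sketch of how `classify` walks the tree, here it is run against a hypothetical miniature tree, with plain dicts standing in for DataFrame rows (both support the same `[]` lookup):

```python
def classify(instance, tree, default=None):
    # Same function as defined above, repeated so this snippet stands alone
    attribute = next(iter(tree))
    if instance[attribute] in tree[attribute]:
        result = tree[attribute][instance[attribute]]
        if isinstance(result, dict):
            return classify(instance, result)
        return result
    return default

# Hypothetical miniature tree, just for illustration:
toy_tree = {'odor': {'pungent': 'poisonous',
                     'none': {'cap-color': {'white': 'poisonous',
                                            'brown': 'edible'}}}}

print(classify({'odor': 'none', 'cap-color': 'brown'}, toy_tree))  # -> edible
print(classify({'odor': 'musty', 'cap-color': 'brown'}, toy_tree,
               default='poisonous'))  # unseen attribute value -> falls back to default
```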
# The classify function allows a default argument: when the tree doesn't have an answer
# for a particular combination of attribute values, we use 'poisonous' as the default
# guess (better safe than sorry!)
df_shroom['predicted'] = df_shroom.apply(classify, axis=1, args=(tree, 'poisonous'))
print('Accuracy is ' + str(sum(df_shroom['class'] == df_shroom['predicted']) / (1.0 * len(df_shroom.index))))
df_shroom[['class', 'predicted']]
Accuracy is 1.0
class | predicted | |
---|---|---|
0 | poisonous | poisonous |
1 | edible | edible |
2 | edible | edible |
3 | poisonous | poisonous |
4 | edible | edible |
5 | edible | edible |
6 | edible | edible |
7 | edible | edible |
8 | poisonous | poisonous |
9 | edible | edible |
10 | edible | edible |
11 | edible | edible |
12 | edible | edible |
13 | poisonous | poisonous |
14 | edible | edible |
15 | edible | edible |
16 | edible | edible |
17 | poisonous | poisonous |
18 | poisonous | poisonous |
19 | poisonous | poisonous |
20 | edible | edible |
21 | poisonous | poisonous |
22 | edible | edible |
23 | edible | edible |
24 | edible | edible |
25 | poisonous | poisonous |
26 | edible | edible |
27 | edible | edible |
28 | edible | edible |
29 | edible | edible |
... | ... | ... |
5614 | edible | edible |
5615 | edible | edible |
5616 | poisonous | poisonous |
5617 | poisonous | poisonous |
5618 | poisonous | poisonous |
5619 | edible | edible |
5620 | edible | edible |
5621 | poisonous | poisonous |
5622 | poisonous | poisonous |
5623 | edible | edible |
5624 | poisonous | poisonous |
5625 | poisonous | poisonous |
5626 | poisonous | poisonous |
5627 | edible | edible |
5628 | edible | edible |
5629 | poisonous | poisonous |
5630 | edible | edible |
5631 | edible | edible |
5632 | edible | edible |
5633 | poisonous | poisonous |
5634 | edible | edible |
5635 | edible | edible |
5636 | edible | edible |
5637 | poisonous | poisonous |
5638 | edible | edible |
5639 | edible | edible |
5640 | edible | edible |
5641 | edible | edible |
5642 | poisonous | poisonous |
5643 | poisonous | poisonous |
5644 rows × 2 columns
Of course, a more honest assessment of the algorithm is to train it on one subset of the data, then test it on a different, held-out subset.
training_data = df_shroom.iloc[:-1000]     # all but the last thousand instances
test_data = df_shroom.iloc[-1000:].copy()  # just the last thousand (a copy, so we can add a column)
train_tree = id3(training_data, 'class', attribute_names)
test_data['predicted2'] = test_data.apply(  # <---- test_data source
    classify,
    axis=1,
    args=(train_tree, 'poisonous'))         # <---- tree built from training_data
print('Accuracy is ' + str(sum(test_data['class'] == test_data['predicted2']) / (1.0 * len(test_data.index))))
Accuracy is 0.944
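Since the rows here are taken in file order, the last thousand rows may not be representative of the whole dataset. A common alternative is to shuffle before splitting. A minimal sketch using pandas `sample` (on a tiny synthetic stand-in for `df_shroom`, with made-up values just to show the mechanics):

```python
import pandas as pd

# Hypothetical stand-in for df_shroom, just to demonstrate the split:
df = pd.DataFrame({'class': ['edible', 'poisonous'] * 50,
                   'odor':  ['almond', 'pungent'] * 50})

# Shuffle all rows (frac=1), then hold out 20% for testing:
shuffled = df.sample(frac=1, random_state=0)
cutoff = int(0.8 * len(shuffled))
train = shuffled.iloc[:cutoff]
test = shuffled.iloc[cutoff:].copy()

print(len(train), len(test))  # -> 80 20
```

A fixed `random_state` keeps the split reproducible between runs.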