Naive Bayes from scratch

In [1]:
import re
import glob
from pathlib import Path
from random import shuffle
from math import exp, log
from collections import defaultdict, Counter
from typing import NamedTuple, List, Set, Tuple, Dict
In [2]:
# Ensure that we have a `data` directory to store the downloaded data
!mkdir -p data
data_dir: Path = Path('data')
In [3]:
# We're using the "Enron Spam" data set
!wget -nc -P data http://nlp.cs.aueb.gr/software_and_datasets/Enron-Spam/preprocessed/enron1.tar.gz
--2020-02-09 12:03:06--  http://nlp.cs.aueb.gr/software_and_datasets/Enron-Spam/preprocessed/enron1.tar.gz
Resolving nlp.cs.aueb.gr (nlp.cs.aueb.gr)... 195.251.248.252
Connecting to nlp.cs.aueb.gr (nlp.cs.aueb.gr)|195.251.248.252|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1802573 (1.7M) [application/x-gzip]
Saving to: ‘data/enron1.tar.gz’

enron1.tar.gz       100%[===================>]   1.72M   920KB/s    in 1.9s    

2020-02-09 12:03:08 (920 KB/s) - ‘data/enron1.tar.gz’ saved [1802573/1802573]

In [4]:
!tar -xzf data/enron1.tar.gz -C data
In [5]:
# The data set has 2 directories: one for `spam` messages, one for `ham` messages
spam_data_path: Path = data_dir / 'enron1' / 'spam'
ham_data_path: Path = data_dir / 'enron1' / 'ham'
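A quick sanity check (a sketch, not part of the original run) that the archive extracted into the two expected directories:

assert spam_data_path.is_dir(), 'expected data/enron1/spam after extraction'
assert ham_data_path.is_dir(), 'expected data/enron1/ham after extraction'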
In [6]:
# Our data container for `spam` and `ham` messages
class Message(NamedTuple):
    text: str
    is_spam: bool
In [7]:
# Globbing for all the `.txt` files in both (`spam` and `ham`) directories
spam_message_paths: List[str] = glob.glob(str(spam_data_path / '*.txt'))
ham_message_paths: List[str] = glob.glob(str(ham_data_path / '*.txt'))

message_paths: List[str] = spam_message_paths + ham_message_paths
message_paths[:5]
Out[7]:
['data/enron1/spam/4743.2005-06-25.GP.spam.txt',
 'data/enron1/spam/1309.2004-06-08.GP.spam.txt',
 'data/enron1/spam/0726.2004-03-26.GP.spam.txt',
 'data/enron1/spam/0202.2004-01-13.GP.spam.txt',
 'data/enron1/spam/3988.2005-03-06.GP.spam.txt']
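Counting the paths per class shows the imbalance we are working with; for enron1 this should come out to 1500 spam and 3672 ham files (5172 in total, matching the message count below):

len(spam_message_paths), len(ham_message_paths)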
In [8]:
# The list which will eventually contain all the parsed Enron `spam` and `ham` messages
messages: List[Message] = []
In [9]:
# Open every file individually, turn it into a `Message` and append it to our `messages` list
for path in message_paths:
    with open(path, errors='ignore') as file:
        is_spam: bool = 'spam' in path
        # We're only interested in the subject for the time being        
        text: str = file.readline().replace('Subject:', '').strip()
        messages.append(Message(text, is_spam))
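We only parse the subject line here. If we ever wanted the full message body as well, a small variation (a sketch; `body` is unused in the rest of this notebook) could read the remainder of the file:

with open(message_paths[0], errors='ignore') as file:
    subject: str = file.readline().replace('Subject:', '').strip()
    body: str = file.read().strip()  # everything after the subject line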
In [10]:
shuffle(messages)
messages[:5]
Out[10]:
[Message(text='january production estimate', is_spam=False),
 Message(text='re : your code # 5 g 6878', is_spam=True),
 Message(text='account # 20367 s tue , 28 jun 2005 11 : 41 : 41 - 0800', is_spam=True),
 Message(text='congratulations', is_spam=True),
 Message(text='fw : hpl imbalance payback', is_spam=False)]
In [11]:
len(messages)
Out[11]:
5172
In [12]:
# Given a string, normalize and extract all words with at least 2 characters
def tokenize(text: str) -> Set[str]:
    words: List[str] = []
    for word in re.findall(r'[A-Za-z0-9\']+', text):
        if len(word) >= 2:
            words.append(word.lower())
    return set(words)

assert tokenize('Is this a text? If so, Tokenize this text!...') == {'is', 'this', 'text', 'if', 'so', 'tokenize'}
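Note that the apostrophe in the regex's character class keeps contractions together as a single token:

assert tokenize("Don't stop me now") == {"don't", 'stop', 'me', 'now'}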
In [13]:
tokenize(messages[0].text)
Out[13]:
{'estimate', 'january', 'production'}
In [14]:
# Split the list of messages into a `train` and a `test` set (defaults to an 80/20 train/test split)
def train_test_split(messages: List[Message], pct: float = 0.8) -> Tuple[List[Message], List[Message]]:
    shuffle(messages)
    num_train = int(round(len(messages) * pct))
    return messages[:num_train], messages[num_train:]

assert len(train_test_split(messages)[0]) + len(train_test_split(messages)[1]) == len(messages)
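We can also sanity-check the split ratio itself, not just the total count; a quick sketch assuming the default `pct` of 0.8 (the `train_msgs` / `test_msgs` names are only used for this check):

train_msgs, test_msgs = train_test_split(messages)
assert abs(len(train_msgs) / len(messages) - 0.8) < 0.01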
In [15]:
# The Naive Bayes classifier
class NaiveBayes:
    def __init__(self, k: int = 1) -> None:
        # `k` is the (Laplace) smoothing factor
        self._k: int = k
        self._num_spam_messages: int = 0
        self._num_ham_messages: int = 0
        self._num_word_in_spam: Dict[str, int] = defaultdict(int)
        self._num_word_in_ham: Dict[str, int] = defaultdict(int)
        self._spam_words: Set[str] = set()
        self._ham_words: Set[str] = set()
        self._words: Set[str] = set()

    # Iterate through the given messages and gather the necessary statistics
    def train(self, messages: List[Message]) -> None:
        msg: Message
        token: str
        for msg in messages:
            tokens: Set[str] = tokenize(msg.text)
            self._words.update(tokens)
            if msg.is_spam:
                self._num_spam_messages += 1
                self._spam_words.update(tokens)
                for token in tokens:
                    self._num_word_in_spam[token] += 1
            else:
                self._num_ham_messages += 1
                self._ham_words.update(tokens)
                for token in tokens:
                    self._num_word_in_ham[token] += 1                
    
    # Probability of seeing `word` in a spam message, i.e. P(word | spam), smoothed by `k`
    def _p_word_spam(self, word: str) -> float:
        return (self._k + self._num_word_in_spam[word]) / ((2 * self._k) + self._num_spam_messages)
    
    # Probability of seeing `word` in a ham message, i.e. P(word | ham), smoothed by `k`
    def _p_word_ham(self, word: str) -> float:
        return (self._k + self._num_word_in_ham[word]) / ((2 * self._k) + self._num_ham_messages)
    
    # Given a `text`, how likely is it spam?
    def predict(self, text: str) -> float:
        text_words: Set[str] = tokenize(text)
        log_p_spam: float = 0.0
        log_p_ham: float = 0.0

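        # Note: this is the Bernoulli flavor of Naive Bayes. We iterate over the
        # whole vocabulary, so every known word counts as evidence, whether it
        # appears in `text` or is absent from it. Summing log-probabilities
        # instead of multiplying raw probabilities avoids floating point underflow.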
        for word in self._words:
            p_spam: float = self._p_word_spam(word)
            p_ham: float = self._p_word_ham(word)
            if word in text_words:
                log_p_spam += log(p_spam)
                log_p_ham += log(p_ham)
            else:
                log_p_spam += log(1 - p_spam)
                log_p_ham += log(1 - p_ham)

        p_if_spam: float = exp(log_p_spam)
        p_if_ham: float = exp(log_p_ham)
        return p_if_spam / (p_if_spam + p_if_ham)

# Tests
def test_naive_bayes():
    messages: List[Message] = [
        Message('Spam message', is_spam=True),
        Message('Ham message', is_spam=False),
        Message('Ham message about Spam', is_spam=False)]
    
    nb: NaiveBayes = NaiveBayes()
    nb.train(messages)
    
    assert nb._num_spam_messages == 1
    assert nb._num_ham_messages == 2
    assert nb._spam_words == {'spam', 'message'}
    assert nb._ham_words == {'ham', 'message', 'about', 'spam'}
    assert nb._num_word_in_spam == {'spam': 1, 'message': 1}
    assert nb._num_word_in_ham == {'ham': 2, 'message': 2, 'about': 1, 'spam': 1}
    assert nb._words == {'spam', 'message', 'ham', 'about'}

    # Our test message
    text: str = 'A spam message'
    
    # Reminder: The `_words` we iterate over are: {'spam', 'message', 'ham', 'about'}
    
    # Calculate how spammy the `text` might be
    p_if_spam: float = exp(sum([
        log(     (1 + 1) / ((2 * 1) + 1)),  # `spam` (also in `text`)
        log(     (1 + 1) / ((2 * 1) + 1)),  # `message` (also in `text`)
        log(1 - ((1 + 0) / ((2 * 1) + 1))), # `ham` (NOT in `text`)
        log(1 - ((1 + 0) / ((2 * 1) + 1))), # `about` (NOT in `text`)
    ]))
    
    # Calculate how hammy the `text` might be
    p_if_ham: float = exp(sum([
        log(     (1 + 1)  / ((2 * 1) + 2)),  # `spam` (also in `text`)
        log(     (1 + 2)  / ((2 * 1) + 2)),  # `message` (also in `text`)
        log(1 - ((1 + 2)  / ((2 * 1) + 2))), # `ham` (NOT in `text`)
        log(1 - ((1 + 1)  / ((2 * 1) + 2))), # `about` (NOT in `text`)
    ]))
    
    p_spam: float = p_if_spam / (p_if_spam + p_if_ham)
    
    assert p_spam == nb.predict(text)

test_naive_bayes()
In [16]:
train: List[Message]
test: List[Message]

# Splitting our Enron messages into a `train` and `test` set
train, test = train_test_split(messages)
In [17]:
# Train our Naive Bayes classifier with the `train` set
nb: NaiveBayes = NaiveBayes()
nb.train(train)

print(f'Spam messages in training data: {nb._num_spam_messages}')
print(f'Ham messages in training data: {nb._num_ham_messages}')
print(f'Most common words in spam messages: {Counter(nb._num_word_in_spam).most_common(20)}')
Spam messages in training data: 1227
Ham messages in training data: 2911
Most common words in spam messages: [('you', 115), ('the', 104), ('your', 104), ('for', 86), ('to', 83), ('re', 81), ('on', 56), ('and', 51), ('get', 48), ('is', 48), ('in', 43), ('with', 40), ('of', 38), ('it', 35), ('at', 35), ('online', 34), ('all', 33), ('from', 33), ('this', 32), ('new', 31)]
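The raw counts above are dominated by stopwords like 'you' and 'the'. To surface words that are actually indicative of spam, one could rank the vocabulary by the smoothed probability ratio instead; a sketch using the classifier's internal helpers:

spammiest: List[str] = sorted(nb._words, key=lambda word: nb._p_word_spam(word) / nb._p_word_ham(word), reverse=True)
spammiest[:10]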
In [18]:
# Grabbing all the spam messages from our `test` set
spam_messages: List[Message] = [item for item in test if item.is_spam]
spam_messages[:5]
Out[18]:
[Message(text="a witch . i don ' t", is_spam=True),
 Message(text='active and strong', is_spam=True),
 Message(text='get great prices on medications', is_spam=True),
 Message(text='', is_spam=True),
 Message(text='popular software at low low prices . misunderstand developments', is_spam=True)]
In [19]:
# Using our trained Naive Bayes classifier to classify a spam message
message: str = spam_messages[10].text
    
print(f'Predicting likelihood of "{message}" being spam.')
nb.predict(message)
Predicting likelihood of "get your hand clock repliacs todday carson" being spam.
Out[19]:
0.9884313222593173
In [20]:
# Grabbing all the ham messages from our `test` set
ham_messages: List[Message] = [item for item in test if not item.is_spam]
ham_messages[:5]
Out[20]:
[Message(text='new update for buybacks', is_spam=False),
 Message(text='enron and blockbuster to launch entertainment on - demand service', is_spam=False),
 Message(text='re : astros web site comments', is_spam=False),
 Message(text='re : formosa meter # : 1000', is_spam=False),
 Message(text='re : deal extension for 11 / 21 / 2000 for 98 - 439', is_spam=False)]
In [21]:
# Using our trained Naive Bayes classifier to classify a ham message
message: str = ham_messages[10].text

print(f'Predicting likelihood of "{text}" being spam.')
nb.predict(message)
Predicting likelihood of "associate & analyst mid - year 2001 prc process" being spam.
Out[21]:
5.3089147140900964e-05
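The two spot checks above look plausible. For an aggregate picture we could score the entire `test` set; a minimal sketch, assuming everything with a predicted spam probability above 0.5 counts as spam:

correct: int = sum(1 for msg in test if (nb.predict(msg.text) > 0.5) == msg.is_spam)
print(f'Accuracy on the test set: {correct / len(test):.3f}')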