%matplotlib inline
import pylab as pl
import numpy as np
# Some nice default configuration for plots
pl.rcParams['figure.figsize'] = 10, 7.5
pl.rcParams['axes.grid'] = True
The sklearn.feature_extraction.text.CountVectorizer and sklearn.feature_extraction.text.TfidfVectorizer classes suffer from a number of scalability issues that all stem from the internal usage of the vocabulary_ attribute (a Python dictionary) used to map the unicode string feature names to the integer feature indices.
The main scalability issues are:
- parallelization of the feature extraction: the vocabulary_ would be a shared state, requiring complex synchronization and overhead
- online or out-of-core / streaming learning: the vocabulary_ needs to be learned from the data, so its size cannot be known before making one pass over the full dataset
To better understand the issue, let's have a look at how the vocabulary_ attribute works. At fit time, each token of the corpus is assigned a unique integer index and this mapping is stored in the vocabulary:
from sklearn.feature_extraction.text import CountVectorizer
vectorizer = CountVectorizer(min_df=1)
vectorizer.fit([
    "The cat sat on the mat.",
])
vectorizer.vocabulary_
{u'cat': 0, u'mat': 1, u'on': 2, u'sat': 3, u'the': 4}
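To make the fit step concrete, here is a minimal pure-Python sketch of what building vocabulary_ amounts to. It assumes the default tokenization (lowercasing, then the token_pattern (?u)\b\w\w+\b visible in the HashingVectorizer repr further below) and alphabetically sorted indices, as in the output above. tokenize and build_vocabulary are hypothetical helpers, not scikit-learn functions, and they ignore options such as min_df or n-grams.

```python
import re

# Assumed default tokenization: tokens of 2 or more word characters.
TOKEN_PATTERN = re.compile(r"(?u)\b\w\w+\b")


def tokenize(text):
    """Lowercase the text and extract the tokens."""
    return TOKEN_PATTERN.findall(text.lower())


def build_vocabulary(documents):
    """Map each distinct token to an integer index, in alphabetical order."""
    tokens = set()
    for doc in documents:
        tokens.update(tokenize(doc))
    return {token: index for index, token in enumerate(sorted(tokens))}


print(build_vocabulary(["The cat sat on the mat."]))
# {'cat': 0, 'mat': 1, 'on': 2, 'sat': 3, 'the': 4}
```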
The vocabulary is used at transform time to build the occurrence matrix:
X = vectorizer.transform([
    "The cat sat on the mat.",
    "This cat is a nice cat.",
]).toarray()
print(len(vectorizer.vocabulary_))
print(vectorizer.get_feature_names())
print(X)
5
[u'cat', u'mat', u'on', u'sat', u'the']
[[1 1 1 1 2]
 [2 0 0 0 0]]
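The transform step boils down to one dictionary lookup per token, and tokens missing from the vocabulary ("this", "is", "nice") are silently dropped. A minimal sketch reusing the vocabulary printed above, under the same assumed default tokenization; count_matrix is a hypothetical helper that returns a dense list of lists rather than the scipy sparse matrix that scikit-learn produces:

```python
import re

TOKEN_PATTERN = re.compile(r"(?u)\b\w\w+\b")


def count_matrix(documents, vocabulary):
    """Dense document-term count matrix built from a fixed vocabulary."""
    rows = []
    for doc in documents:
        row = [0] * len(vocabulary)
        for token in TOKEN_PATTERN.findall(doc.lower()):
            index = vocabulary.get(token)
            if index is not None:  # out-of-vocabulary tokens are ignored
                row[index] += 1
        rows.append(row)
    return rows


vocabulary = {'cat': 0, 'mat': 1, 'on': 2, 'sat': 3, 'the': 4}
print(count_matrix(["The cat sat on the mat.", "This cat is a nice cat."], vocabulary))
# [[1, 1, 1, 1, 2], [2, 0, 0, 0, 0]]
```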
Let's refit with a slightly larger corpus:
vectorizer = CountVectorizer(min_df=1)
vectorizer.fit([
    "The cat sat on the mat.",
    "The quick brown fox jumps over the lazy dog.",
])
vectorizer.vocabulary_
{u'brown': 0, u'cat': 1, u'dog': 2, u'fox': 3, u'jumps': 4, u'lazy': 5, u'mat': 6, u'on': 7, u'over': 8, u'quick': 9, u'sat': 10, u'the': 11}
The vocabulary_ grows with the size of the training corpus (sublinearly, as new documents mostly reuse words already seen). Note that we could not have built the vocabularies in parallel on the 2 text documents: they share some words, hence this would require some kind of shared data structure or synchronization barrier, which is complicated to set up, especially if we want to distribute the processing on a cluster.
With this new vocabulary, the dimensionality of the output space is now larger:
X = vectorizer.transform([
    "The cat sat on the mat.",
    "This cat is a nice cat.",
]).toarray()
print(len(vectorizer.vocabulary_))
print(vectorizer.get_feature_names())
print(X)
12
[u'brown', u'cat', u'dog', u'fox', u'jumps', u'lazy', u'mat', u'on', u'over', u'quick', u'sat', u'the']
[[0 1 0 0 0 0 1 1 0 0 1 2]
 [0 2 0 0 0 0 0 0 0 0 0 0]]
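The synchronization problem can be made concrete with a small pure-Python sketch. build_vocabulary is a hypothetical helper that mimics the alphabetical indexing seen in the outputs above. Two workers that each index their own document independently disagree on the index of a shared word such as "the", so their partial results cannot simply be concatenated: a global re-indexing over all tokens is needed.

```python
import re

TOKEN_PATTERN = re.compile(r"(?u)\b\w\w+\b")


def build_vocabulary(documents):
    """Alphabetically indexed vocabulary of lowercased tokens."""
    tokens = set()
    for doc in documents:
        tokens.update(TOKEN_PATTERN.findall(doc.lower()))
    return {token: index for index, token in enumerate(sorted(tokens))}


# Each "worker" indexes its own share of the corpus independently.
vocab_a = build_vocabulary(["The cat sat on the mat."])
vocab_b = build_vocabulary(["The quick brown fox jumps over the lazy dog."])
print(vocab_a["the"], vocab_b["the"])  # 4 7: the indices disagree

# A consistent mapping requires seeing all the tokens in one place.
merged = build_vocabulary(["The cat sat on the mat.",
                           "The quick brown fox jumps over the lazy dog."])
print(merged["the"])  # 11
```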
To illustrate the scalability issues of the vocabulary-based vectorizers, let's load a more realistic dataset for a classical text classification task: sentiment analysis on tweets. The goal is to tell apart negative from positive tweets on a variety of topics.
Assuming that the ../fetch_data.py script was run successfully, the following files should be available:
import os
sentiment140_folder = os.path.join('datasets', 'sentiment140')
training_csv_file = os.path.join(sentiment140_folder, 'training.1600000.processed.noemoticon.csv')
testing_csv_file = os.path.join(sentiment140_folder, 'testdata.manual.2009.06.14.csv')
Those files were downloaded from the research archive of the http://www.sentiment140.com/ project. The first file was gathered using the Twitter streaming API by running stream queries for the positive ":)" and negative ":(" emoticons to collect tweets that are explicitly positive or negative. To make the classification problem non-trivial, the emoticons were stripped out of the text in the CSV files:
!ls -lh datasets/sentiment140/training.1600000.processed.noemoticon.csv
-rw-r--r-- 1 varoquau varoquau 228M Jul 21 10:01 datasets/sentiment140/training.1600000.processed.noemoticon.csv
Let's parse the CSV files and load them in memory. As loading all of the 1.6M tweets can take up to 2GB, let's limit the collection to 100K tweets of each polarity (positive and negative).
FIELDNAMES = ('polarity', 'id', 'date', 'query', 'author', 'text')
def read_csv(csv_file, fieldnames=FIELDNAMES, max_count=None,
             n_partitions=1, partition_id=0):
    import csv  # put the import inside for use in IPython.parallel
    texts = []
    targets = []
    with open(csv_file, 'rb') as f:
        reader = csv.DictReader(f, fieldnames=fieldnames,
                                delimiter=',', quotechar='"')
        pos_count, neg_count = 0, 0
        for i, d in enumerate(reader):
            if i % n_partitions != partition_id:
                # Skip entry if not in the requested partition
                continue
            if d['polarity'] == '4':
                if max_count and pos_count >= max_count / 2:
                    continue
                pos_count += 1
                texts.append(d['text'])
                targets.append(1)
            elif d['polarity'] == '0':
                if max_count and neg_count >= max_count / 2:
                    continue
                neg_count += 1
                texts.append(d['text'])
                targets.append(-1)
    return texts, targets
%time text_train_all, target_train_all = read_csv(training_csv_file, max_count=200000)
CPU times: user 7.94 s, sys: 68.2 ms, total: 8.01 s
Wall time: 8 s
len(text_train_all), len(target_train_all)
(200000, 200000)
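The n_partitions and partition_id arguments of read_csv are not used in the call above. As hinted by the comment about IPython.parallel, they let several independent workers each load a disjoint share of the CSV file without any coordination: row i goes to partition i % n_partitions (each worker still scans the whole file, it just skips the other rows). The rule is sketched here on a plain list; partition is a hypothetical helper, not part of the notebook.

```python
def partition(items, n_partitions, partition_id):
    """Round-robin partitioning, same rule as in read_csv."""
    return [item for i, item in enumerate(items)
            if i % n_partitions == partition_id]


rows = ["row%d" % i for i in range(10)]
parts = [partition(rows, 3, pid) for pid in range(3)]
for part in parts:
    print(part)
```

The partitions are disjoint and together cover every row, which is why no shared state is needed. Note that max_count in read_csv is then applied per worker.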
Let's display the first samples:
for text in text_train_all[:3]:
    print(text + "\n")
@switchfoot http://twitpic.com/2y1zl - Awww, that's a bummer. You shoulda got David Carr of Third Day to do it. ;D

is upset that he can't update his Facebook by texting it... and might cry as a result School today also. Blah!

@Kenichan I dived many times for the ball. Managed to save 50% The rest go out of bounds
print(target_train_all[:3])
[-1, -1, -1]
A polarity of "0" means negative while a polarity of "4" means positive. All the positive tweets are at the end of the file:
for text in text_train_all[-3:]:
    print(text + "\n")
Okie doke!! Time for me to escape for the North while Massa's back is turned. Be on when I get home folks

finished the lessons, hooray!

Some ppl are just fucking KP0. Cb ! Stop asking me laa.. I love my boyfriend and thats it.
print(target_train_all[-3:])
[1, 1, 1]
Let's split the training CSV file into a smaller training set and a validation set with 100k random tweets each:
from sklearn.cross_validation import train_test_split
text_train_small, text_validation, target_train_small, target_validation = train_test_split(
    text_train_all, target_train_all, test_size=.5, random_state=0)
len(text_train_small)
100000
# Let's make numpy arrays out of these
target_train_small = np.array(target_train_small)
target_validation = np.array(target_validation)
np.sum(target_train_small == -1), np.sum(target_train_small == 1)
(50195, 49805)
len(text_validation)
100000
np.sum(target_validation == -1), np.sum(target_validation == 1)
(49805, 50195)
Let's open the manually annotated tweet file. The evaluation set also has neutral tweets with a polarity of "2", which read_csv ignores, so the final evaluation set contains only the positive and negative tweets of the evaluation CSV file:
text_test_all, target_test_all = read_csv(testing_csv_file)
len(text_test_all), len(target_test_all)
(359, 359)
To work around the limitations of the vocabulary-based vectorizers, one can use the hashing trick. Instead of building and storing an explicit mapping from the feature names to the feature indices in a Python dict, we can just use a hash function and a modulus operation:
from sklearn.utils.murmurhash import murmurhash3_bytes_u32
for word in "the cat sat on the mat".split():
    print("{0} => {1}".format(
        word, murmurhash3_bytes_u32(word, 0) % 2 ** 20))
the => 761698
cat => 300839
sat => 122804
on => 735689
the => 761698
mat => 122997
This mapping is completely stateless and the dimensionality of the output space is explicitly fixed in advance (here we use a modulo 2 ** 20, which means roughly 1M dimensions). This makes it possible to work around the limitations of the vocabulary-based vectorizer, both for parallelizability and for online / out-of-core learning.
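The same idea can be sketched with the standard library only. This is not the murmurhash function used by scikit-learn: zlib.crc32 is swapped in purely for illustration, and feature_index is a hypothetical helper. What matters is that the mapping needs no fitted state, so two workers hashing different documents agree on every shared index without communicating.

```python
import zlib

N_FEATURES = 2 ** 20  # fixed in advance, independent of the data


def feature_index(token, n_features=N_FEATURES):
    """Map a token to a column index with a hash function and a modulus.

    zlib.crc32 stands in for murmurhash: any deterministic hash
    function illustrates the idea.
    """
    return zlib.crc32(token.encode("utf-8")) % n_features


# Two independent "workers": no vocabulary is learned or shared.
worker_a = {w: feature_index(w) for w in "the cat sat on the mat".split()}
worker_b = {w: feature_index(w) for w in "the quick brown fox".split()}
print(worker_a["the"] == worker_b["the"])  # True
```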
The HashingVectorizer class is an alternative to the TfidfVectorizer class with use_idf=False that internally uses the murmurhash hash function:
from sklearn.feature_extraction.text import HashingVectorizer
h_vectorizer = HashingVectorizer(encoding='latin-1')
h_vectorizer
HashingVectorizer(analyzer=u'word', binary=False, charset=None, charset_error=None, decode_error=u'strict', dtype=<type 'numpy.float64'>, encoding='latin-1', input=u'content', lowercase=True, n_features=1048576, ngram_range=(1, 1), non_negative=False, norm=u'l2', preprocessor=None, stop_words=None, strip_accents=None, token_pattern=u'(?u)\\b\\w\\w+\\b', tokenizer=None)
It shares the same "preprocessor", "tokenizer" and "analyzer" infrastructure as the vocabulary-based vectorizers:
analyzer = h_vectorizer.build_analyzer()
analyzer('This is a test sentence.')
[u'this', u'is', u'test', u'sentence']
We can vectorize our datasets into a scipy sparse matrix exactly as we would have done with the CountVectorizer or TfidfVectorizer, except that we can directly call the transform method: there is no need to fit, as HashingVectorizer is a stateless transformer:
%time X_train_small = h_vectorizer.transform(text_train_small)
CPU times: user 1.82 s, sys: 8.17 ms, total: 1.83 s
Wall time: 1.82 s
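Because transform needs no fitted state, its core fits in a few lines. The following pure-Python sketch builds the three arrays of a CSR matrix (data, indices, indptr), which store only the non-zero entries of each row. Assumptions: zlib.crc32 replaces murmurhash, scikit-learn's handling of hash signs is left out, and each row is l2-normalized as suggested by norm=u'l2' in the repr above. hashing_transform is a hypothetical helper.

```python
import math
import re
import zlib
from collections import Counter

TOKEN_PATTERN = re.compile(r"(?u)\b\w\w+\b")


def hashing_transform(documents, n_features=2 ** 20):
    """Return CSR arrays (data, indices, indptr) of l2-normalized hashed counts."""
    data, indices, indptr = [], [], [0]
    for doc in documents:
        counts = Counter()
        for token in TOKEN_PATTERN.findall(doc.lower()):
            counts[zlib.crc32(token.encode("utf-8")) % n_features] += 1
        norm = math.sqrt(sum(c * c for c in counts.values()))
        for index in sorted(counts):
            indices.append(index)
            data.append(counts[index] / norm)
        indptr.append(len(indices))  # row i spans indptr[i]:indptr[i + 1]
    return data, indices, indptr


data, indices, indptr = hashing_transform(["The cat sat on the mat.", ""])
print(len(data), indptr)
```

The length of data and indices depends only on the number of distinct hashed tokens per document, never on n_features, which is why the 2 ** 20 columns cost almost nothing in memory.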
The dimension of the output is fixed ahead of time to n_features=2 ** 20 by default (nearly 1M features) to minimize the rate of collisions on most classification problems while keeping the linear models reasonably sized (1M weights in the coef_ attribute):
X_train_small
<100000x1048576 sparse matrix of type '<type 'numpy.float64'>' with 1184803 stored elements in Compressed Sparse Row format>
As only the non-zero elements are stored, n_features has little impact on the actual size of the data in memory. We can combine the hashing vectorizer with a Passive-Aggressive linear model in a pipeline:
from sklearn.linear_model import PassiveAggressiveClassifier
from sklearn.pipeline import Pipeline
h_pipeline = Pipeline((
    ('vec', HashingVectorizer(encoding='latin-1')),
    ('clf', PassiveAggressiveClassifier(C=1, n_iter=1)),
))
%time h_pipeline.fit(text_train_small, target_train_small).score(text_validation, target_validation)
CPU times: user 3.96 s, sys: 32.6 ms, total: 3.99 s
Wall time: 3.95 s
0.74768000000000001
Let's check that the score on the validation set is reasonably in line with the set of manually annotated tweets:
h_pipeline.score(text_test_all, target_test_all)
0.74930362116991645
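For intuition about what the clf step does with each sample, here is a minimal sketch of the PA-I update rule, which, to our understanding, is what PassiveAggressiveClassifier implements with its default hinge loss. The weights stay unchanged when a sample is classified with a margin of at least 1 (passive); otherwise they move just enough to fix the error, with the step size capped by C (aggressive). The helper names are hypothetical, samples are {index: value} dicts, and the intercept is omitted.

```python
def dot(w, x):
    """Dot product between a weight dict and a sparse sample dict."""
    return sum(w.get(i, 0.0) * v for i, v in x.items())


def pa_update(w, x, y, C=1.0):
    """One PA-I step with the hinge loss; y is -1 or +1."""
    loss = max(0.0, 1.0 - y * dot(w, x))
    sq_norm = sum(v * v for v in x.values())
    if loss > 0.0 and sq_norm > 0.0:
        tau = min(C, loss / sq_norm)  # step size, capped by C
        for i, v in x.items():
            w[i] = w.get(i, 0.0) + tau * y * v
    return w


def predict(w, x):
    return 1 if dot(w, x) > 0 else -1


w = {}
for x, y in [({0: 1.0}, 1), ({1: 1.0}, -1)]:
    pa_update(w, x, y)
print(w)  # {0: 1.0, 1: -1.0}
```

Each update only touches the features present in the sample, which keeps training cheap on sparse hashed text features.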
As the text_train_small dataset is not that big, we can still use a vocabulary-based vectorizer to check that the hashing collisions are not causing any significant performance drop on the validation set (WARNING: this cell can take a while to run; skip it if your computer is too slow):
from sklearn.feature_extraction.text import TfidfVectorizer
vocabulary_vec = TfidfVectorizer(encoding='latin-1', use_idf=False)
vocabulary_pipeline = Pipeline((
    ('vec', vocabulary_vec),
    ('clf', PassiveAggressiveClassifier(C=1, n_iter=1)),
))
%time vocabulary_pipeline.fit(text_train_small, target_train_small).score(text_validation, target_validation)
CPU times: user 3.27 s, sys: 60.7 ms, total: 3.33 s
Wall time: 3.29 s
0.74802000000000002
We get almost the same score, but the vocabulary-based vectorizer also keeps a big data structure in memory that is slow to (un)pickle:
len(vocabulary_vec.vocabulary_)
91405
More information and references to the original papers on the hashing trick can be found in the answers to this http://metaoptimize.com/qa question: What is the Hashing Trick?
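Out-of-core learning is the task of training a machine learning model on a dataset that does not fit in the main memory. This requires the following conditions:
- a feature extraction layer with a fixed output dimensionality
- knowing the list of all classes in advance (in this case we only have positive and negative tweets)
- a machine learning algorithm that supports incremental learning (the partial_fit method in scikit-learn).
Let us simulate an infinite Twitter stream that can generate batches of annotated tweet texts and their polarity. We can do this by randomly recombining pairs of positive or negative tweets from our fixed dataset: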
from random import Random
class InfiniteStreamGenerator(object):
    """Simulate random polarity queries on the twitter streaming API"""
    def __init__(self, texts, targets, seed=0, batchsize=100):
        self.texts_pos = [text for text, target in zip(texts, targets)
                          if target > 0]
        self.texts_neg = [text for text, target in zip(texts, targets)
                          if target <= 0]
        self.rng = Random(seed)
        self.batchsize = batchsize
    def next_batch(self, batchsize=None):
        batchsize = self.batchsize if batchsize is None else batchsize
        texts, targets = [], []
        for i in range(batchsize):
            # Select the polarity randomly
            target = self.rng.choice((-1, 1))
            targets.append(target)
            # Combine 2 random texts of the right polarity
            pool = self.texts_pos if target > 0 else self.texts_neg
            text = self.rng.choice(pool) + " " + self.rng.choice(pool)
            texts.append(text)
        return texts, targets
infinite_stream = InfiniteStreamGenerator(text_train_small, target_train_small)
texts_in_batch, targets_in_batch = infinite_stream.next_batch(batchsize=3)
for t in texts_in_batch:
    print(t + "\n")
It's sunny outside, so @penguingirl74 and I are inside playing Gears 2 Co-Op just ate Lucky Me! Curly Spaghetti. Instant Spaghetti finally made right.

@mirandafox Poor Princess Twitchy I was so shocked!! Just woke up feel crap

@samkoh hahaha, but were you twittering and driving!? remind me never to ride in your car! @ithinkminh well, you have till Oct to get there
targets_in_batch
[1, -1, 1]
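InfiniteStreamGenerator samples from lists that are already in memory. In a genuine out-of-core setting, batches would instead be read lazily from disk or from a network stream. A minimal stdlib sketch of that pattern; iter_minibatches is a hypothetical helper (not part of scikit-learn) that works on any iterable, for instance an open file object yielding one line per document:

```python
from itertools import islice


def iter_minibatches(lines, batchsize):
    """Lazily yield lists of at most `batchsize` items from any iterable.

    Only one batch is held in memory at a time.
    """
    iterator = iter(lines)
    while True:
        batch = list(islice(iterator, batchsize))
        if not batch:
            return
        yield batch


print(list(iter_minibatches(range(5), 2)))  # [[0, 1], [2, 3], [4]]
```

Each batch could then go through the vectorizer's transform and the classifier's partial_fit, just like the batches of the simulated stream.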
We can now use our infinite tweet source to train an online machine learning algorithm using the hashing vectorizer. Note the use of the partial_fit method of the PassiveAggressiveClassifier instance in place of the traditional call to the fit method, which needs access to the full training set.
From time to time, we evaluate the current predictive performance of the model on our validation set that is guaranteed to not overlap with the infinite training set source:
n_batches = 1000
validation_scores = []
training_set_size = []
# Build the vectorizer and the classifier
h_vectorizer = HashingVectorizer(encoding='latin-1')
clf = PassiveAggressiveClassifier(C=1)
# Extract the features for the validation once and for all
X_validation = h_vectorizer.transform(text_validation)
classes = np.array([-1, 1])
n_samples = 0
for i in range(n_batches):
    texts_in_batch, targets_in_batch = infinite_stream.next_batch()
    n_samples += len(texts_in_batch)
    # Vectorize the text documents in the batch
    X_batch = h_vectorizer.transform(texts_in_batch)
    # Incrementally train the model on the new batch
    clf.partial_fit(X_batch, targets_in_batch, classes=classes)
    if n_samples % 100 == 0:
        # Compute the validation score of the current state of the model
        score = clf.score(X_validation, target_validation)
        validation_scores.append(score)
        training_set_size.append(n_samples)
    if i % 100 == 0:
        print("n_samples: {0}, score: {1:.4f}".format(n_samples, score))
n_samples: 100, score: 0.5411
n_samples: 10100, score: 0.7215
n_samples: 20100, score: 0.7413
n_samples: 30100, score: 0.7536
n_samples: 40100, score: 0.7528
n_samples: 50100, score: 0.7381
n_samples: 60100, score: 0.7560
n_samples: 70100, score: 0.7535
n_samples: 80100, score: 0.7543
n_samples: 90100, score: 0.7575
We can now plot the collected validation scores versus the number of samples generated by the infinite source and fed to the model:
pl.plot(training_set_size, validation_scores)
pl.ylim(0.5, 1)
pl.xlabel("Number of samples")
pl.ylabel("Validation score")
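Using the hashing vectorizer makes it possible to implement streaming and parallel text classification, but it can also introduce some issues:
- collisions: distinct tokens can be mapped to the same feature index
- HashingVectorizer does not provide "Inverse Document Frequency" reweighting (lack of a use_idf=True option)
- there is no inverse mapping from feature indices back to the string feature names (no get_feature_names() method)
The collision issues can be controlled by increasing the n_features parameter.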
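The effect of n_features can be estimated under the simplifying assumption that the hash function spreads tokens uniformly over n buckets. Each of the other V - 1 distinct tokens lands in the same bucket as a given token with probability 1/n, so the fraction of tokens involved in at least one collision is about 1 - (1 - 1/n)^(V - 1) ≈ 1 - exp(-(V - 1)/n). With the V = 91405 tokens of the vocabulary measured above and the default n = 2 ** 20, that is roughly 8%, and each doubling of n_features roughly halves it. A sketch comparing this estimate with an empirical count (zlib.crc32 stands in for murmurhash; both helper names are hypothetical):

```python
import zlib
from collections import Counter


def expected_collision_fraction(n_tokens, n_features):
    """Fraction of tokens sharing a bucket, assuming uniform hashing."""
    return 1.0 - (1.0 - 1.0 / n_features) ** (n_tokens - 1)


def observed_collision_fraction(tokens, n_features):
    """Fraction of distinct tokens that share their bucket with another token."""
    tokens = set(tokens)
    if not tokens:
        return 0.0
    buckets = Counter(zlib.crc32(t.encode("utf-8")) % n_features for t in tokens)
    colliding = sum(count for count in buckets.values() if count > 1)
    return colliding / float(len(tokens))


# 91405 is the vocabulary size measured on text_train_small
for n_features in (2 ** 18, 2 ** 20, 2 ** 22):
    print(n_features, round(expected_collision_fraction(91405, n_features), 4))
```

Many of the colliding tokens are rare ones, which is consistent with the small score difference observed between the hashing and vocabulary-based pipelines.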
The IDF weighting might be reintroduced by appending a TfidfTransformer instance to the output of the vectorizer. However, computing the idf_ statistic used for the feature reweighting requires at least one additional pass over the training set before the classifier can start training: this breaks the online learning scheme.
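The extra pass shows up clearly in a pure-Python sketch: document frequencies must be accumulated over the whole (hashed) training stream before a single document can be reweighted, so training cannot start on the first batch. The smoothed formula idf = ln((1 + n) / (1 + df)) + 1 is assumed here because it is scikit-learn's documented default (smooth_idf=True); idf_from_stream and apply_idf are hypothetical helpers.

```python
import math
from collections import Counter


def idf_from_stream(hashed_docs):
    """First pass: idf of each hashed feature index.

    hashed_docs is an iterable of {feature_index: count} dicts.
    """
    df = Counter()
    n_docs = 0
    for doc in hashed_docs:
        n_docs += 1
        df.update(doc.keys())  # each feature counted once per document
    return {i: math.log((1.0 + n_docs) / (1.0 + d)) + 1.0 for i, d in df.items()}


def apply_idf(doc, idf):
    """Second pass: reweight a document whose features were seen in the first pass."""
    return {i: v * idf[i] for i, v in doc.items()}


docs = [{0: 1}, {0: 1, 1: 1}]
idf = idf_from_stream(iter(docs))
print(apply_idf({0: 2, 1: 1}, idf))
```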
The lack of an inverse mapping (the get_feature_names() method of TfidfVectorizer) is even harder to work around. That would require extending the HashingVectorizer class with a "trace" mode that records the mapping of the most important features to provide statistical debugging information.
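Such a trace mode does not exist in HashingVectorizer. The class below is only a hypothetical sketch of the idea: it records, for each feature index, the set of tokens that were hashed to it (zlib.crc32 again replaces murmurhash).

```python
import zlib
from collections import defaultdict


class TracingHasher(object):
    """Hypothetical "trace" mode: remember which tokens hit each index."""

    def __init__(self, n_features=2 ** 20):
        self.n_features = n_features
        self.trace = defaultdict(set)  # feature index -> set of tokens

    def index(self, token):
        # zlib.crc32 stands in for murmurhash in this sketch
        i = zlib.crc32(token.encode("utf-8")) % self.n_features
        self.trace[i].add(token)
        return i

    def feature_names(self, i):
        """All tokens seen so far that were hashed to column i."""
        return sorted(self.trace.get(i, ()))


hasher = TracingHasher()
i = hasher.index("cat")
print(hasher.feature_names(i))  # ['cat'] unless another traced token collides
```

Recording every token brings back the memory cost of a vocabulary, which is why only the most important features would be worth tracing.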
In the meantime, to debug feature extraction issues, it is recommended to use TfidfVectorizer(use_idf=False) on a small-ish subset of the dataset to simulate a HashingVectorizer() instance that has the get_feature_names() method and no collision issues.