from string import punctuation
from nltk import RegexpTokenizer
from nltk.stem.porter import PorterStemmer
from nltk.corpus import stopwords
from sklearn.datasets import fetch_20newsgroups
# Load the 20 newsgroups corpus through scikit-learn (default arguments).
newsgroups = fetch_20newsgroups()
# A set gives constant-time stopword membership tests during tokenization.
eng_stopwords = set(stopwords.words('english'))
# gaps=True: the pattern describes the separators, so tokens are the text
# between whitespace runs. The pattern is a raw string so that ``\s`` is not
# treated as an (invalid) string escape sequence.
tokenizer = RegexpTokenizer(r'\s+', gaps=True)
stemmer = PorterStemmer()
# Translation table that maps every ASCII punctuation code point to a space.
translate_tab = {ord(p): u" " for p in punctuation}
def text2tokens(raw_text):
    """
    Convert raw text to a list of stemmed tokens.

    The text is lower-cased, punctuation is replaced by spaces, the result is
    split on whitespace, English stopwords are removed and the remaining
    tokens are stemmed. Stems of two characters or fewer are discarded.

    :param raw_text: text of a single document
    :return: list of stemmed tokens, in document order
    """
    clean_text = raw_text.lower().translate(translate_tab)
    tokens = [token.strip() for token in tokenizer.tokenize(clean_text)]
    tokens = [token for token in tokens if token not in eng_stopwords]
    stemmed_tokens = [stemmer.stem(token) for token in tokens]
    # Build a real list instead of returning ``filter(...)``: on Python 3
    # ``filter`` is a one-shot iterator, so a document could only be read
    # once and every later pass over it would see no tokens at all.
    return [token for token in stemmed_tokens if len(token) > 2]  # skip short tokens
# Convert every document to a list of tokens. Each document is materialized
# with list() because the corpus is iterated more than once further down
# (vocabulary construction, then doc2bow); a lazy, one-shot iterator such as
# a Python 3 ``filter`` object would be empty on the second pass.
dataset = [list(text2tokens(txt)) for txt in newsgroups['data']]
from gensim.corpora import Dictionary
# Build the token <-> id mapping over the whole corpus; prune_at=None disables
# the cap on vocabulary size while the documents are being scanned.
dictionary = Dictionary(dataset, prune_at=None)

# Remove irrelevant tokens: no_below / no_above set the document-frequency
# bounds (absolute count / fraction of the corpus), keep_n=None keeps
# everything that survives those bounds.
dictionary.filter_extremes(keep_n=None, no_above=0.3, no_below=5)
dictionary.compactify()

# Bag-of-words representation of every tokenized document.
d2b_dataset = list(map(dictionary.doc2bow, dataset))
%%time
from gensim.models import LdaMulticore
num_topics = 15

# Two models are trained on the same bag-of-words corpus with identical
# settings; the only difference between them is the number of passes.
lda_fst = LdaMulticore(
    corpus=d2b_dataset,
    id2word=dictionary,
    num_topics=num_topics,
    passes=10,
    batch=True,
    eval_every=None,
    workers=4,
)
lda_snd = LdaMulticore(
    corpus=d2b_dataset,
    id2word=dictionary,
    num_topics=num_topics,
    passes=20,
    batch=True,
    eval_every=None,
    workers=4,
)
# Output of the timed cell above:
# CPU times: user 3min 29s, sys: 39.8 s, total: 4min 9s; Wall time: 5min 2s
import plotly.offline as py
import plotly.graph_objs as go
# Prepare plotly's offline mode for notebook output; this is called once,
# before any py.iplot(...) call below.
py.init_notebook_mode()
def plot_difference(mdiff, title="", annotation=None):
    """
    Render a heatmap of the difference matrix between two models.

    :param mdiff: matrix of values passed to the heatmap as ``z``
    :param title: title shown above the figure
    :param annotation: optional nested sequence (rows of cells); each cell is
        a pair of token sequences. The first is listed after ``+++`` and the
        second after ``---`` in the text attached to that heatmap cell.
    :return: None, the figure is displayed through ``py.iplot``
    """
    if annotation is None:
        hover_text = None
    else:
        cell_template = "+++ {}<br>--- {}"
        hover_text = []
        for annotation_row in annotation:
            text_row = []
            for first_tokens, second_tokens in annotation_row:
                text_row.append(
                    cell_template.format(", ".join(first_tokens),
                                         ", ".join(second_tokens))
                )
            hover_text.append(text_row)

    heatmap = go.Heatmap(z=mdiff, colorscale='RdBu', text=hover_text)
    # Both axes index topics, so they carry the same label.
    figure_layout = go.Layout(
        width=950,
        height=950,
        title=title,
        xaxis={"title": "topic"},
        yaxis={"title": "topic"},
    )
    py.iplot({"data": [heatmap], "layout": figure_layout})