#!/usr/bin/env python
# coding: utf-8

# # Visualizing Topic clusters
#
# In this notebook, we will learn how to visualize topic clusters using a dendrogram.
# A dendrogram is a tree-structured graph that can be used to visualize the result of
# a hierarchical clustering calculation. Hierarchical clustering puts individual data
# points into similarity groups without prior knowledge of the groups. We can use it
# to explore topic models and see how the topics are connected to each other through
# the sequence of successive fusions or divisions that occur during clustering.

# In[ ]:


# plotly>=2.0.16 is needed for the 'hovertext' argument of create_dendrogram.
# The requirement is quoted so the shell does not treat '>' as an output
# redirection (unquoted, pip would install any plotly version and its output
# would be written to a file named '=2.0.16').
get_ipython().system('pip install "plotly>=2.0.16"')


# In[1]:


from gensim.models.ldamodel import LdaModel
from gensim.corpora import Dictionary
from gensim.parsing.preprocessing import remove_stopwords, strip_punctuation

import numpy as np
import pandas as pd
import re

import plotly.offline as py
import plotly.graph_objs as go
import plotly.figure_factory as ff

py.init_notebook_mode()


# # Train Model
#
# We'll use the [fake news dataset](https://www.kaggle.com/mrisdal/fake-news) from Kaggle
# for this notebook. The first step is to preprocess the data and train our topic model
# using LDA. You can also refer to this
# [notebook](https://github.com/RaRe-Technologies/gensim/blob/develop/docs/notebooks/lda_training_tips.ipynb)
# for tips and suggestions on pre-processing text data and on training an LDA model
# that gives good results.
# In[2]:


df_fake = pd.read_csv('fake.csv')
df_fake[['title', 'text', 'language']].head()

# Keep only English articles that have a body text. `.copy()` makes this an
# independent frame, so assigning the 'text' column below does not trigger
# pandas' SettingWithCopyWarning.
df_fake = df_fake.loc[(pd.notnull(df_fake.text)) & (df_fake.language == 'english')].copy()


# remove stopwords and punctuations
def preprocess(row):
    """Lower-case *row*, remove gensim's stopwords from it, then strip punctuation."""
    return strip_punctuation(remove_stopwords(row.lower()))


df_fake['text'] = df_fake['text'].apply(preprocess)

# Convert data to the input format required by LDA: one token list per document.
# The text is already lower-cased by `preprocess`. For str patterns `\w` matches
# Unicode word characters; re.LOCALE is only allowed with bytes patterns on
# Python 3.6+ (combining it with a str pattern raises ValueError).
texts = [re.findall(r'\w+', line, flags=re.UNICODE) for line in df_fake.text]

# Create a dictionary representation of the documents.
dictionary = Dictionary(texts)

# Filter out words that occur in fewer than 2 documents, or in more than 40% of the documents.
dictionary.filter_extremes(no_below=2, no_above=0.4)

# Bag-of-words representation of the documents.
corpus_fake = [dictionary.doc2bow(text) for text in texts]


# In[ ]:


lda_fake = LdaModel(corpus=corpus_fake, id2word=dictionary, num_topics=35, passes=30,
                    chunksize=1500, iterations=200, alpha='auto')
lda_fake.save('lda_35')


# In[3]:


lda_fake = LdaModel.load('lda_35')


# # Basic Dendrogram
#
# First, a distance matrix is calculated to store the distance between every pair of
# topics. These distances are then used in ascending order to merge the topics into
# clusters; the dendrogram depicts this merging process.

# In[4]:


from gensim.matutils import jensen_shannon
from scipy import spatial as scs
from scipy.cluster import hierarchy as sch
from scipy.spatial.distance import pdist, squareform


# get topic distributions: one row per topic (see `topic_dist.shape[0]` below).
# NOTE(review): these are the model's lambda parameters, not row-normalized
# probabilities — confirm this is the intended input for the JS distance.
topic_dist = lda_fake.state.get_lambda()

# get topic terms: the set of the top `num_words` terms of each topic
num_words = 300
topic_terms = [{w for (w, _) in lda_fake.show_topic(topic, topn=num_words)}
               for topic in range(topic_dist.shape[0])]
# no. of terms to display in annotation
n_ann_terms = 10


# use Jensen-Shannon distance metric in dendrogram
def js_dist(X):
    """Return the condensed pairwise Jensen-Shannon distances between the rows of *X*."""
    return pdist(X, jensen_shannon)


# define method for distance calculation in clusters
def linkagefun(x):
    """Cluster the condensed distance vector *x* with single linkage."""
    return sch.linkage(x, 'single')


def _end_annotation(term_sets, n_ann_terms):
    """Format one end of a dendrogram link as '+++: [...]<br>---: [...]'.

    *term_sets* is an ``(intersecting, different)`` pair of term sets; at most
    *n_ann_terms* terms of each are shown.
    """
    intersecting, different = term_sets
    # plotly renders "<br>" (not a raw newline) as a line break in hover text
    return "<br>".join((
        ": ".join(("+++", str(list(intersecting)[:n_ann_terms]))),
        ": ".join(("---", str(list(different)[:n_ann_terms]))),
    ))


# calculate text annotations
def text_annotation(topic_dist, topic_terms, n_ann_terms, linkagefun):
    """Build hover-text annotations for every link of the topic dendrogram.

    Parameters
    ----------
    topic_dist : array-like
        One row per topic; the rows are compared with `js_dist`.
    topic_terms : list of set of str
        ``topic_terms[i]`` is the set of terms of topic ``i``.
    n_ann_terms : int
        Maximum number of terms shown per annotation entry.
    linkagefun : callable
        Turns a condensed distance vector into a linkage matrix. Pass the same
        function given to ``ff.create_dendrogram`` so the annotations line up
        with the plotted links.

    Returns
    -------
    list of list
        One ``[t1, t2, t3, t4]`` entry per link, in the order of scipy's
        ``dendrogram(...)['icoord']``. ``t1``/``t4`` annotate the left/right
        ends of the link; ``t2``/``t3`` are empty.
    """
    # get dendrogram hierarchy data with the caller-supplied linkage method
    # (this previously re-bound `linkagefun` to single linkage, silently
    # ignoring the argument)
    d = js_dist(topic_dist)
    Z = linkagefun(d)
    P = sch.dendrogram(Z, orientation="bottom", no_plot=True)

    # x position of each leaf: leaves are assumed to sit at 5, 15, 25, ... in
    # scipy's dendrogram coordinates, the same coordinates used by P['icoord']
    x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10)
    x_topic = dict(zip(P['leaves'], x_ticks))

    # {x position: (intersecting terms, different terms)}; a leaf starts out
    # with its topic's full term set in both slots
    topic_vals = {x: (topic_terms[topic], topic_terms[topic])
                  for topic, x in x_topic.items()}

    text_annotations = []
    # loop through every link (one U-shaped trace) of the dendrogram; this
    # relies on scipy listing a link after both of its children, so both ends
    # are already present in topic_vals
    for trace in P['icoord']:
        fst_topic = topic_vals[trace[0]]
        scnd_topic = topic_vals[trace[2]]

        # annotation for the two ends of the current link. The terms come from
        # sets, so which n_ann_terms terms are shown is arbitrary.
        t1 = _end_annotation(fst_topic, n_ann_terms)
        t2 = t3 = ()
        t4 = _end_annotation(scnd_topic, n_ann_terms)

        # at a leaf end, show the topic's own terms instead of +++/---
        if trace[0] in x_ticks:
            t1 = str(list(fst_topic[0])[:n_ann_terms])
        if trace[2] in x_ticks:
            t4 = str(list(scnd_topic[0])[:n_ann_terms])

        text_annotations.append([t1, t2, t3, t4])

        # the merged node, located at the link's midpoint, keeps the common
        # terms (+++) and the symmetric difference (---) for the next level up
        intersecting = fst_topic[0] & scnd_topic[0]
        different = fst_topic[0].symmetric_difference(scnd_topic[0])
        center = (trace[0] + trace[2]) / 2
        topic_vals[center] = (intersecting, different)

        # remove trace value after it is annotated
        topic_vals.pop(trace[0], None)
        topic_vals.pop(trace[2], None)

    return text_annotations


# In[5]:


# get text annotations
annotation = text_annotation(topic_dist, topic_terms, n_ann_terms, linkagefun)

# Plot dendrogram; the 1-based topic labels follow the number of topics
# instead of being hard-coded
topic_labels = list(range(1, topic_dist.shape[0] + 1))
dendro = ff.create_dendrogram(topic_dist, distfun=js_dist, labels=topic_labels,
                              linkagefun=linkagefun, hovertext=annotation)
dendro['layout'].update({'width': 1000, 'height': 600})
py.iplot(dendro)


# The x-axis, i.e. the leaves of the hierarchy, represents the topics of our LDA model;
# the y-axis is a measure of the closeness of either individual topics or their clusters.
# Essentially, the y-axis level at which two branches merge (relative to the "root" of
# the tree) is related to their similarity. For example, topics 4 and 30 are more similar
# to each other than to topic 32. In addition, topics 18 and 24 are more similar to
# topic 35 than topics 4 and 30 are to topic 32, because the height at which they merge
# is lower than the merge height of 4/30 with 32.
#
# The text annotations, visible when hovering over the cluster nodes, show the
# intersecting/different terms of the node's two child nodes. A cluster node on the first
# hierarchy level uses the topics on the leaves directly to calculate the
# intersecting/different terms, while the upper nodes take the intersection (+++) as the
# topic terms of their child nodes.
#
# This type of tree graph could help us see high-level cluster themes that might exist
# in our data, since we can see the common/different terms of the combined topics in a
# cluster head's annotation.

# ## Dendrogram with a Heatmap
#
# Now let's append the distance matrix of the topics below the dendrogram in the form of
# a heatmap, so that we can see the exact distances between all pairs of topics.

# In[6]:


# get text annotations
annotation = text_annotation(topic_dist, topic_terms, n_ann_terms, linkagefun)

# Initialize figure by creating upper dendrogram (1-based topic labels)
figure = ff.create_dendrogram(topic_dist, distfun=js_dist,
                              labels=list(range(1, topic_dist.shape[0] + 1)),
                              linkagefun=linkagefun, hovertext=annotation)
# draw the dendrogram on the secondary y-axis; the heatmap uses the primary one
for trace in figure['data']:
    trace['yaxis'] = 'y2'


# In[7]:


# get distance matrix and its topic annotations
mdiff, annotation = lda_fake.diff(lda_fake, distance="jensen_shannon", normed=False)

# get reordered topic list (the dendrogram labels are 1-based topic ids)
dendro_leaves = [x - 1 for x in figure['layout']['xaxis']['ticktext']]

# reorder the distance matrix rows and columns to follow the dendrogram's leaf order
heat_data = mdiff[np.ix_(dendro_leaves, dendro_leaves)]


# In[8]:


# heatmap annotation, reordered exactly like heat_data so that each cell's hover
# text describes the same topic pair as its z-value. plotly renders "<br>"
# (not a raw newline) as a line break in hover text.
annotation_html = [["+++ {}<br>--- {}".format(", ".join(annotation[r][c][0]),
                                              ", ".join(annotation[r][c][1]))
                    for c in dendro_leaves]
                   for r in dendro_leaves]

tickvals = figure['layout']['xaxis']['tickvals']

# plot heatmap of distance matrix, aligned with the dendrogram leaves.
# A plain list replaces the deprecated go.Data; 'YlGnBu' (yellow-green-blue) is
# the plotly colorscale name ('YIGnBu' is rejected by plotly>=3).
heatmap = [
    go.Heatmap(
        z=heat_data,
        x=tickvals,
        y=tickvals,
        colorscale='YlGnBu',
        text=annotation_html,
        hoverinfo='x+y+z+text'
    )
]

# Add Heatmap Data to Figure. In plotly>=3, `figure['data']` is an immutable
# tuple without `.extend`, so use `add_traces` there; older versions keep a list.
if hasattr(figure, 'add_traces'):
    figure.add_traces(heatmap)
else:
    figure['data'].extend(heatmap)

# back to 1-based topic labels for the axis tick text
dendro_leaves = [x + 1 for x in dendro_leaves]

# Edit Layout
figure['layout'].update({'width': 800, 'height': 800,
                         'showlegend': False, 'hovermode': 'closest',
                         })

# Edit xaxis
figure['layout']['xaxis'].update({'domain': [.25, 1],
                                  'mirror': False,
                                  'showgrid': False,
                                  'showline': False,
                                  "showticklabels": True,
                                  "tickmode": "array",
                                  "ticktext": dendro_leaves,
                                  "tickvals": tickvals,
                                  'zeroline': False,
                                  'ticks': ""})
# Edit yaxis
figure['layout']['yaxis'].update({'domain': [0, 0.75],
                                  'mirror': False,
                                  'showgrid': False,
                                  'showline': False,
                                  "showticklabels": True,
                                  "tickmode": "array",
                                  "ticktext": dendro_leaves,
                                  "tickvals": tickvals,
                                  'zeroline': False,
                                  'ticks': ""})
# Edit yaxis2
figure['layout'].update({'yaxis2': {'domain': [0.75, 1],
                                    'mirror': False,
                                    'showgrid': False,
                                    'showline': False,
                                    'zeroline': False,
                                    'showticklabels': False,
                                    'ticks': ""}})

py.iplot(figure)


# The heatmap shows the exact distance between any two topics as the z-value of their
# corresponding cell, along with their intersecting (+++) and different (---) terms in the
# hover annotation. This also helps us see the distances between topics that are not
# directly connected in the dendrogram.