#!/usr/bin/env python
# coding: utf-8

# # Visualizing Topic clusters
#
# In this notebook, we will learn how to visualize topic clusters using a
# dendrogram. A dendrogram is a tree-structured graph that visualizes the
# result of a hierarchical clustering calculation. Hierarchical clustering
# puts individual data points into similarity groups without prior knowledge
# of the groups. We can use it to explore topic models and see how the topics
# are connected to each other through the sequence of successive fusions (or
# divisions) that occur during clustering.

# In[1]:


from gensim.models.ldamodel import LdaModel
from gensim.corpora import Dictionary
from gensim.parsing.preprocessing import remove_stopwords, strip_punctuation
import numpy as np
import pandas as pd
import re

import plotly.offline as py
import plotly.graph_objs as go

py.init_notebook_mode()


# # Train Model
#
# We'll use the [fake news dataset](https://www.kaggle.com/mrisdal/fake-news)
# from Kaggle for this notebook. The first step is to preprocess the data and
# train our topic model using LDA. You can also refer to this
# [notebook](https://github.com/RaRe-Technologies/gensim/blob/develop/docs/notebooks/lda_training_tips.ipynb)
# for tips and suggestions on pre-processing the text data and on training an
# LDA model that gives good results.

# In[2]:


df_fake = pd.read_csv('fake.csv')
df_fake[['title', 'text', 'language']].head()  # notebook preview only

# Keep English articles that have a body. .copy() gives an independent frame,
# so assigning to the 'text' column below does not write into a view of the
# original DataFrame (which would trigger SettingWithCopyWarning).
df_fake = df_fake.loc[(pd.notnull(df_fake.text)) & (df_fake.language == 'english')].copy()


# remove stopwords and punctuations
def preprocess(row):
    """Lower-case *row*, remove stopwords, then strip punctuation."""
    return strip_punctuation(remove_stopwords(row.lower()))


df_fake['text'] = df_fake['text'].apply(preprocess)

# Convert data to the input format required by LDA: one token list per document.
# The text is already lower-cased by preprocess(). re.LOCALE must not be used
# here: combined with a str pattern it raises ValueError on Python 3. re.UNICODE
# alone keeps \w Unicode-aware.
WORD_RE = re.compile(r'\w+', flags=re.UNICODE)
texts = [WORD_RE.findall(line) for line in df_fake.text]
# Create a dictionary representation of the documents.
dictionary = Dictionary(texts)

# Filter out words that occur in fewer than 2 documents, or in more than 40%
# of the documents.
dictionary.filter_extremes(no_below=2, no_above=0.4)

# Bag-of-words representation of the documents.
corpus_fake = [dictionary.doc2bow(text) for text in texts]


# In[ ]:


lda_fake = LdaModel(corpus=corpus_fake, id2word=dictionary, num_topics=35, passes=30,
                    chunksize=1500, iterations=200, alpha='auto')
lda_fake.save('lda_35')


# In[3]:


lda_fake = LdaModel.load('lda_35')


# # Basic Dendrogram
#
# First, a distance matrix is calculated that stores the distance between
# every pair of topics. Topics are then merged into clusters in order of
# increasing distance, and the dendrogram depicts this merging process.

# In[4]:


# This input cell contains the modified code from Plotly[1].
# It can be removed after PR (https://github.com/plotly/plotly.py/pull/807) gets merged.

# [1] https://github.com/plotly/plotly.py/blob/master/plotly/figure_factory/_dendrogram.py

from collections import OrderedDict

from plotly import exceptions, optional_imports
from plotly.graph_objs import graph_objs

# Optional imports, may be None for users that only use our core functionality.
np = optional_imports.get_module('numpy')
scp = optional_imports.get_module('scipy')
sch = optional_imports.get_module('scipy.cluster.hierarchy')
scs = optional_imports.get_module('scipy.spatial')


def create_dendrogram(X, orientation="bottom", labels=None, colorscale=None,
                      distfun=None,
                      linkagefun=lambda x: sch.linkage(x, 'single'),
                      annotation=None):
    """
    BETA function that returns a dendrogram Plotly figure object.

    :param (ndarray) X: Matrix of observations as array of arrays
    :param (str) orientation: 'top', 'right', 'bottom', or 'left'
    :param (list) labels: List of axis category labels(observation labels)
    :param (list) colorscale: Optional colorscale for dendrogram tree
    :param (function) distfun: Function to compute the pairwise distance from
                               the observations (default: scipy pdist)
    :param (function) linkagefun: Function to compute the linkage matrix from
                                  the pairwise distances
    :param (list) annotation: Optional hover text, one entry per link of the
                              tree; entry i becomes the ``text`` of the i-th
                              link trace (one value per point of that link)
    :raises ImportError: if scipy, scipy.spatial or scipy.cluster.hierarchy
                         cannot be imported
    :raises PlotlyError: if X is not 2-dimensional
    :rtype (dict): figure dict with 'layout' and 'data' keys

    Example 1: Simple bottom oriented dendrogram
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_dendrogram

    import numpy as np

    X = np.random.rand(10,10)
    dendro = create_dendrogram(X)
    plot_url = py.plot(dendro, filename='simple-dendrogram')

    ```

    Example 2: Dendrogram to put on the left of the heatmap
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_dendrogram

    import numpy as np

    X = np.random.rand(5,5)
    names = ['Jack', 'Oxana', 'John', 'Chelsea', 'Mark']
    dendro = create_dendrogram(X, orientation='right', labels=names)
    dendro['layout'].update({'width':700, 'height':500})

    py.iplot(dendro, filename='vertical-dendrogram')
    ```

    Example 3: Dendrogram with Pandas
    ```
    import plotly.plotly as py
    from plotly.figure_factory import create_dendrogram

    import numpy as np
    import pandas as pd

    Index= ['A','B','C','D','E','F','G','H','I','J']
    df = pd.DataFrame(abs(np.random.randn(10, 10)), index=Index)
    fig = create_dendrogram(df, labels=Index)
    url = py.plot(fig, filename='pandas-dendrogram')
    ```
    """
    if not scp or not scs or not sch:
        raise ImportError("FigureFactory.create_dendrogram requires scipy, "
                          "scipy.spatial and scipy.hierarchy")

    s = X.shape
    if len(s) != 2:
        # The error must actually be raised, not just constructed.
        raise exceptions.PlotlyError("X should be 2-dimensional array.")

    if distfun is None:
        distfun = scs.distance.pdist

    dendrogram = _Dendrogram(X, orientation, labels, colorscale,
                             distfun=distfun, linkagefun=linkagefun,
                             annotation=annotation)

    return {'layout': dendrogram.layout,
            'data': dendrogram.data}


class _Dendrogram(object):
    """Refer to FigureFactory.create_dendrogram() for docstring."""

    def __init__(self, X, orientation='bottom', labels=None, colorscale=None,
                 width="100%", height="100%", xaxis='xaxis', yaxis='yaxis',
                 distfun=None,
                 linkagefun=lambda x: sch.linkage(x, 'single'),
                 annotation=None):
        self.orientation = orientation
        self.labels = labels
        self.xaxis = xaxis
        self.yaxis = yaxis
        self.data = []
        self.leaves = []
        self.sign = {self.xaxis: 1, self.yaxis: 1}
        self.layout = {self.xaxis: {}, self.yaxis: {}}

        # A sign of -1 negates that axis' coordinates in get_dendrogram_traces().
        self.sign[self.xaxis] = 1 if self.orientation in ['left', 'bottom'] else -1
        self.sign[self.yaxis] = 1 if self.orientation in ['right', 'bottom'] else -1

        if distfun is None:
            distfun = scs.distance.pdist

        (dd_traces, xvals, yvals,
         ordered_labels, leaves) = self.get_dendrogram_traces(X, colorscale,
                                                              distfun,
                                                              linkagefun,
                                                              annotation)

        self.labels = ordered_labels
        self.leaves = leaves

        # Distinct x positions of link points at height 0 (the leaves), sorted
        # left to right; used as the tick positions of the label axis.
        self.zero_vals = sorted({x for x, y in zip(xvals.flatten(), yvals.flatten())
                                 if y == 0.0})

        self.layout = self.set_figure_layout(width, height)
        self.data = graph_objs.Data(dd_traces)

    def get_color_dict(self, colorscale):
        """
        Returns colorscale used for dendrogram tree clusters.

        :param (list) colorscale: Colors to use for the plot in rgb format.
        :rtype (dict): A dict of default colors mapped to the user colorscale.
        """

        # These are the color codes returned for dendrograms
        # We're replacing them with nicer colors
        d = {'r': 'red',
             'g': 'green',
             'b': 'blue',
             'c': 'cyan',
             'm': 'magenta',
             'y': 'yellow',
             'k': 'black',
             'w': 'white'}
        default_colors = OrderedDict(sorted(d.items(), key=lambda t: t[0]))

        if colorscale is None:
            colorscale = [
                'rgb(0,116,217)',  # blue
                'rgb(35,205,205)',  # cyan
                'rgb(61,153,112)',  # green
                'rgb(40,35,35)',  # black
                'rgb(133,20,75)',  # magenta
                'rgb(255,65,54)',  # red
                'rgb(255,255,255)',  # white
                'rgb(255,220,0)']  # yellow

        # Pair the alphabetically sorted codes with the colorscale; codes past
        # the end of a short colorscale keep their plain color names.
        for code, color in zip(list(default_colors), colorscale):
            default_colors[code] = color

        # SciPy >= 1.5 reports link colors as matplotlib cycle codes 'C0'..'C9'
        # instead of single letters, which would make the lookup in
        # get_dendrogram_traces() raise KeyError. Alias them to the legacy
        # codes: 'C0' is the above-threshold color (formerly 'b'), and 'C1'..
        # cycle through the former default link palette g, r, c, m, y, k.
        legacy_palette = ['g', 'r', 'c', 'm', 'y', 'k']
        default_colors.setdefault('C0', default_colors['b'])
        for n in range(1, 10):
            legacy = legacy_palette[(n - 1) % len(legacy_palette)]
            default_colors.setdefault('C%d' % n, default_colors[legacy])

        return default_colors

    def set_axis_layout(self, axis_key):
        """
        Sets and returns default axis object for dendrogram figure.

        :param (str) axis_key: E.g., 'xaxis', 'xaxis1', 'yaxis', yaxis1', etc.
        :rtype (dict): An axis_key dictionary with set parameters.
        """
        axis_defaults = {
            'type': 'linear',
            'ticks': 'outside',
            'mirror': 'allticks',
            'rangemode': 'tozero',
            'showticklabels': True,
            'zeroline': False,
            'showgrid': False,
            'showline': True,
        }

        if len(self.labels) != 0:
            # Leaf labels go on the axis the leaves are spread along.
            axis_key_labels = self.xaxis
            if self.orientation in ['left', 'right']:
                axis_key_labels = self.yaxis
            if axis_key_labels not in self.layout:
                self.layout[axis_key_labels] = {}
            self.layout[axis_key_labels]['tickvals'] = \
                [zv * self.sign[axis_key] for zv in self.zero_vals]
            self.layout[axis_key_labels]['ticktext'] = self.labels
            self.layout[axis_key_labels]['tickmode'] = 'array'

        self.layout[axis_key].update(axis_defaults)

        return self.layout[axis_key]

    def set_figure_layout(self, width, height):
        """
        Sets and returns default layout object for dendrogram figure.
        """
        self.layout.update({
            'showlegend': False,
            'autosize': False,
            'hovermode': 'closest',
            'width': width,
            'height': height
        })

        self.set_axis_layout(self.xaxis)
        self.set_axis_layout(self.yaxis)

        return self.layout

    def get_dendrogram_traces(self, X, colorscale, distfun, linkagefun, annotation):
        """
        Calculates all the elements needed for plotting a dendrogram.

        :param (ndarray) X: Matrix of observations as array of arrays
        :param (list) colorscale: Color scale for dendrogram tree clusters
        :param (function) distfun: Function to compute the pairwise distance
                                   from the observations
        :param (function) linkagefun: Function to compute the linkage matrix
                                      from the pairwise distances
        :param (list) annotation: Optional hover text per link (see
                                  create_dendrogram)
        :rtype (tuple): Contains all the traces in the following order:
            (a) trace_list: List of Plotly trace objects for dendrogram tree
            (b) icoord: All X points of the dendrogram tree as array of arrays
                with length 4
            (c) dcoord: All Y points of the dendrogram tree as array of arrays
                with length 4
            (d) ordered_labels: leaf labels in the order they are going to
                appear on the plot
            (e) P['leaves']: left-to-right traversal of the leaves
        """
        d = distfun(X)
        Z = linkagefun(d)
        P = sch.dendrogram(Z, orientation=self.orientation,
                           labels=self.labels, no_plot=True)

        # scipy.array was merely an alias of numpy.array and has been removed
        # from recent SciPy releases, so numpy is used directly.
        icoord = np.array(P['icoord'])
        dcoord = np.array(P['dcoord'])
        ordered_labels = np.array(P['ivl'])
        color_list = P['color_list']
        colors = self.get_color_dict(colorscale)

        # Trace axis references are 'x'/'y' plus the numeric suffix of the axis
        # name (e.g. 'xaxis2' -> 'x2'). The suffix must be a str: concatenating
        # the int from int(...) to 'x' would raise TypeError.
        try:
            x_index = str(int(self.xaxis[-1]))
        except ValueError:
            x_index = ''
        try:
            y_index = str(int(self.yaxis[-1]))
        except ValueError:
            y_index = ''

        trace_list = []

        for i, (ic, dc, color_key) in enumerate(zip(icoord, dcoord, color_list)):
            # xs and ys are arrays of 4 points that make up the '∩' shapes
            # of the dendrogram tree
            if self.orientation in ['top', 'bottom']:
                xs, ys = ic, dc
            else:
                xs, ys = dc, ic

            text_annotation = annotation[i] if annotation else None

            trace = graph_objs.Scatter(
                x=np.multiply(self.sign[self.xaxis], xs),
                y=np.multiply(self.sign[self.yaxis], ys),
                mode='lines',
                marker=graph_objs.Marker(color=colors[color_key]),
                text=text_annotation,
                hoverinfo='text'
            )

            trace['xaxis'] = 'x' + x_index
            trace['yaxis'] = 'y' + y_index

            trace_list.append(trace)

        return trace_list, icoord, dcoord, ordered_labels, P['leaves']


# In[7]:


from gensim.matutils import jensen_shannon
from scipy import spatial as scs
from scipy.cluster import hierarchy as sch
from scipy.spatial.distance import pdist, squareform

# get topic-term weights from the LDA state; rows are treated as topics
# (topic_dist.shape[0] is used as the number of topics below)
topic_dist = lda_fake.state.get_lambda()

# get topic terms: the set of the top `num_words` terms of every topic
num_words = 300
topic_terms = [{w for (w, _) in lda_fake.show_topic(topic, topn=num_words)}
               for topic in range(topic_dist.shape[0])]

# no. of terms to display in annotation
n_ann_terms = 10


# use Jensen-Shannon distance metric in dendrogram
def js_dist(X):
    """Return scipy's condensed pairwise Jensen-Shannon distance matrix of X's rows."""
    return pdist(X, jensen_shannon)


def _node_hover_text(node, n_terms):
    """Format an (intersecting, different) term-set pair as '+++: [...]<br>---: [...]'."""
    intersecting, different = node
    # "<br>" is the line break in Plotly hover text.
    return "<br>".join((": ".join(("+++", str(list(intersecting)[:n_terms]))),
                        ": ".join(("---", str(list(different)[:n_terms])))))


# calculate text annotations
def text_annotation(topic_dist, topic_terms, n_ann_terms):
    """Build hover text for every link of the topic dendrogram.

    Re-runs single-linkage clustering on the Jensen-Shannon distances of
    ``topic_dist``, then walks the links in ``P['icoord']`` order. For each
    link, four values are produced, one per point of its '∩' shape:

    * The two feet describe the child nodes. A leaf shows its own topic terms.
      An internal node shows its '+++' (intersecting) and '---' (different)
      terms.
    * The two shoulders get an empty tuple.

    :param topic_dist: topic-term matrix, one row per topic
    :param topic_terms: list of term sets; ``topic_terms[k]`` belongs to topic k
    :param n_ann_terms: maximum number of terms shown per hover line
    :rtype (list): one ``[t1, t2, t3, t4]`` list per link. This matches the
        link-trace order of ``create_dendrogram`` when it is called with
        ``distfun=js_dist`` and the default single linkage.
    """
    # get dendrogram hierarchy data
    d = js_dist(topic_dist)
    Z = sch.linkage(d, 'single')
    P = sch.dendrogram(Z, orientation="bottom", no_plot=True)

    # scipy places the leaves at x = 5, 15, 25, ...; map each leaf x-tick to
    # the topic number drawn there
    x_ticks = np.arange(5, len(P['leaves']) * 10 + 5, 10)
    x_topic = dict(zip(P['leaves'], x_ticks))

    # {x position: (intersecting terms, different terms)}; a leaf uses its own
    # topic terms for both entries
    topic_vals = {val: (topic_terms[key], topic_terms[key])
                  for key, val in x_topic.items()}

    text_annotations = []

    # Loop through every trace (link) in the dendrogram. This relies on a
    # link's children being listed before it, so that their centres are
    # already in topic_vals.
    for trace in P['icoord']:
        fst_topic = topic_vals[trace[0]]
        scnd_topic = topic_vals[trace[2]]

        # annotation for the two feet of the current link: leaves show their
        # topic terms, internal nodes their +++/--- terms
        if trace[0] in x_ticks:
            t1 = str(list(fst_topic[0])[:n_ann_terms])
        else:
            t1 = _node_hover_text(fst_topic, n_ann_terms)
        if trace[2] in x_ticks:
            t4 = str(list(scnd_topic[0])[:n_ann_terms])
        else:
            t4 = _node_hover_text(scnd_topic, n_ann_terms)
        # the two shoulder points carry no text
        t2 = t3 = ()
        text_annotations.append([t1, t2, t3, t4])

        # The merged node at the link's centre is described by two term sets:
        # the terms shared by both children, and the terms only one child has.
        intersecting = fst_topic[0] & scnd_topic[0]
        different = fst_topic[0].symmetric_difference(scnd_topic[0])
        center = (trace[0] + trace[2]) / 2
        topic_vals[center] = (intersecting, different)

        # remove trace value after it is annotated
        topic_vals.pop(trace[0], None)
        topic_vals.pop(trace[2], None)

    return text_annotations


# In[8]:


# get text annotations
annotation = text_annotation(topic_dist, topic_terms, n_ann_terms)

# Plot dendrogram; labels are the 1-based topic numbers
dendro = create_dendrogram(topic_dist, distfun=js_dist,
                           labels=range(1, topic_dist.shape[0] + 1),
                           annotation=annotation)
dendro['layout'].update({'width': 1000, 'height': 600})
py.iplot(dendro)


# The x-axis, i.e. the leaves of the hierarchy, represents the topics of our
# LDA model. The y-axis measures the closeness of individual topics or of their
# clusters. Essentially, the y-axis level at which the branches merge (relative
# to the "root" of the tree) reflects their similarity. In our trained model,
# for example, topics 4 and 30 are more similar to each other than to topic 32.
# In addition, topics 18 and 24 are more similar to topic 35 than topics 4 and
# 30 are to topic 32. This is because the height at which they merge is lower
# than the height at which 4/30 merges with 32.
#
# Text annotations, visible when hovering over the cluster nodes, show the
# intersecting/different terms of the node's two child nodes. A cluster node
# on the first hierarchy level uses the topics on the leaves directly to
# calculate the intersecting/different terms. Upper nodes use the intersection
# (+++) of each child node as that child's topic terms.
#
# This type of tree graph can help us spot high-level cluster themes in our
# data. The annotation at each cluster head shows the common/different terms
# of the topics combined in that cluster.

# ## Dendrogram with a Heatmap
#
# Now let's append the distance matrix of the topics below the dendrogram in
# the form of a heatmap, so that we can see the exact distance between every
# pair of topics.

# In[9]:


# get text annotations
annotation = text_annotation(topic_dist, topic_terms, n_ann_terms)

# Initialize figure by creating upper dendrogram (labels: 1-based topic numbers)
figure = create_dendrogram(topic_dist, distfun=js_dist,
                           labels=range(1, topic_dist.shape[0] + 1),
                           annotation=annotation)
# draw the dendrogram against the second y-axis; the heatmap uses the first
for trace in figure['data']:
    trace['yaxis'] = 'y2'


# In[17]:


# get distance matrix and its topic annotations
mdiff, annotation = lda_fake.diff(lda_fake, distance="jensen_shannon", normed=False)

# get reordered topic list: the dendrogram tick labels are 1-based topic
# numbers in left-to-right leaf order; convert them to 0-based indices
dendro_leaves = [x - 1 for x in figure['layout']['xaxis']['ticktext']]

# reorder distance matrix rows and columns into dendrogram leaf order
heat_data = mdiff[np.ix_(dendro_leaves, dendro_leaves)]

# Reorder the per-cell term annotations the same way. Otherwise the hover text
# of a cell would describe a different topic pair than its z-value.
annotation = [[annotation[i][j] for j in dendro_leaves] for i in dendro_leaves]


# In[18]:


# heatmap annotation ("<br>" is the line break in Plotly hover text)
annotation_html = [["+++ {}<br>--- {}".format(", ".join(int_tokens), ", ".join(diff_tokens))
                    for (int_tokens, diff_tokens) in row] for row in annotation]

# plot heatmap of distance matrix
heatmap = go.Data([
    go.Heatmap(
        z=heat_data,
        colorscale='YlGnBu',
        text=annotation_html,
        hoverinfo='x+y+z+text'
    )
])

# align heatmap cells with the dendrogram leaf positions on both axes
heatmap[0]['x'] = figure['layout']['xaxis']['tickvals']
heatmap[0]['y'] = figure['layout']['xaxis']['tickvals']

# Add Heatmap Data to Figure
figure['data'].extend(heatmap)

# back to 1-based topic numbers for the tick labels
dendro_leaves = [x + 1 for x in dendro_leaves]

# Edit Layout
figure['layout'].update({'width': 800, 'height': 800,
                         'showlegend': False, 'hovermode': 'closest',
                         })

# Edit xaxis
figure['layout']['xaxis'].update({'domain': [.25, 1],
                                  'mirror': False,
                                  'showgrid': False,
                                  'showline': False,
                                  "showticklabels": True,
                                  "tickmode": "array",
                                  "ticktext": dendro_leaves,
                                  "tickvals": figure['layout']['xaxis']['tickvals'],
                                  'zeroline': False,
                                  'ticks': ""})
# Edit yaxis
figure['layout']['yaxis'].update({'domain': [0, 0.75],
                                  'mirror': False,
                                  'showgrid': False,
                                  'showline': False,
                                  "showticklabels": True,
                                  "tickmode": "array",
                                  "ticktext": dendro_leaves,
                                  "tickvals": figure['layout']['xaxis']['tickvals'],
                                  'zeroline': False,
                                  'ticks': ""})
# Edit yaxis2
figure['layout'].update({'yaxis2': {'domain': [0.75, 1],
                                    'mirror': False,
                                    'showgrid': False,
                                    'showline': False,
                                    'zeroline': False,
                                    'showticklabels': False,
                                    'ticks': ""}})

py.iplot(figure)


# The heatmap shows the exact distance between any two topics as the z-value
# of their cell. Each cell's +++/--- annotation lists the two topics'
# intersecting and different terms. This also makes it possible to see the
# distance between topics that are not directly connected in the dendrogram.