Confusion Matrix as Sankey Diagram
In machine learning, a confusion matrix is a table used to evaluate how well the predictions of a classification model match the actual labels, typically in supervised learning. It helps us understand the model's behavior and interpret its results.
Usually, a confusion matrix is displayed as raw numbers in an array. Very often we visualize a confusion matrix by plotting it as a heatmap. But there is also another, more elegant and interactive way. In this notebook we will describe step-by-step the process of creating an interactive Sankey confusion matrix using Plotly.
These are the main features of a Sankey confusion matrix:
import numpy as np
import pandas as pd
pd.set_option('display.max_colwidth', None)
import os
# Classification metrics
from sklearn.metrics import confusion_matrix
# Plotly dependencies
from plotly import graph_objects as go
# Set the appropriate renderer in Jupyter Lab so that Plotly displays figures correctly
# Set the default renderer explicitly to iframe
import plotly.io as pio
pio.renderers.default = 'iframe'
# If multiple notebooks are using 'iframe', set different 'html_directory' for each notebook
iframe_renderer = pio.renderers['iframe']
iframe_renderer.html_directory='iframe_figures_n1'
To help us with visualizations, we will import the module metrics_utilities.py. It is a collection of several helper functions for confusion matrix visualization.
The function developed in this notebook will be added to this module.
# Import the script from a different folder
import sys
sys.path.append('./scripts')
import metrics_utilities as mu
The data used in this notebook comes from one of my previous projects, Bank-Churn-Prediction.
I saved the true (actual) labels and the predictions in .npy format, and in the next two cells we will load them.
# True labels
y_test = np.load('./data/y_test.npy')
# Prepare predictions for our models
pred_dt = np.load('./data/pred_dt.npy')
pred_dl = np.load('./data/pred_dl.npy')
pred_knn = np.load('./data/pred_knn.npy')
pred_lr = np.load('./data/pred_lr.npy')
pred_rf = np.load('./data/pred_rf.npy')
pred_svm = np.load('./data/pred_svm.npy')
pred_xgb = np.load('./data/pred_xgb.npy')
The target_names variable holds the names of our classes. It will be used later for displaying evaluation results.
# Names of our classes
target_names = ['Stays', 'Exits']
In the literature, we can find two variants for representing the samples in a confusion matrix:
In this notebook we will use the first variant, where actual labels are on the horizontal axis and predicted labels on the vertical axis. Let us consider a binary classification problem with two classes: 0 (Negative) and 1 (Positive). The confusion matrix would be:
As real-life data is frequently imbalanced, utilizing a confusion matrix without normalization can potentially lead to misleading or incorrect conclusions.
In our Sankey diagram we will include values from both the unnormalized and the normalized matrices.
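As a small illustration of the binary case, here is a sketch with made-up labels (the arrays `y_true` and `y_pred` are invented for this example and are not the churn data). With scikit-learn's `confusion_matrix`, rows correspond to actual classes and columns to predicted classes, so flattening a 2x2 matrix yields TN, FP, FN, TP in that order.

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy labels, invented for illustration: 0 = Negative, 1 = Positive
y_true = np.array([0, 0, 0, 0, 1, 1, 1, 0, 1, 0])
y_pred = np.array([0, 0, 1, 0, 1, 0, 1, 0, 1, 1])

toy_cm = confusion_matrix(y_true, y_pred)

# Row-major flattening of the 2x2 matrix: TN, FP, FN, TP
tn, fp, fn, tp = toy_cm.ravel()
print(toy_cm)
print(f"TN={tn}, FP={fp}, FN={fn}, TP={tp}")
```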
The simplest way to display a confusion matrix is as raw numbers in an array.
Unnormalized Confusion Matrix
# Confusion matrix
cm = confusion_matrix(y_test, pred_dt)
print(cm)
[[1979  410]
 [ 198  413]]
Normalized Confusion Matrix
A few points to know:
# Normalized confusion matrix
cmn = np.around(cm / cm.sum(axis=1)[:, np.newaxis], 2)
print(cmn)
[[0.83 0.17]
 [0.32 0.68]]
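The row normalization in the cell above divides each row by the number of samples in that actual class. As a hedged sketch on invented toy labels (not the churn data), the same result can also be obtained with the `normalize='true'` option of scikit-learn's `confusion_matrix`:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

# Toy imbalanced labels, invented for illustration
y_true = np.array([0] * 8 + [1] * 2)
y_pred = np.array([0, 0, 0, 0, 0, 0, 1, 1, 1, 0])

toy_cm = confusion_matrix(y_true, y_pred)

# Manual row normalization: divide each row by its total (per actual class)
manual = toy_cm / toy_cm.sum(axis=1, keepdims=True)

# Built-in option: normalize over the true (row) labels
builtin = confusion_matrix(y_true, y_pred, normalize='true')

print(np.around(manual, 2))
print(np.allclose(manual, builtin))
```

Each row of the normalized matrix sums to 1, so the values can be read as per-class rates regardless of how many samples each class has.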
A Sankey Diagram is a visual tool used to illustrate the transfer of energy, money, materials, or the flow of any isolated system or process.
It provides a clear depiction of flows and their quantities, showcasing the proportion of values transferring from one set to another. In a Sankey Diagram, the interconnected elements are referred to as nodes, and the connections between them (flows) are known as links.
To understand Sankey diagrams, it is fundamental to become familiar with the key terminology:
To create a Sankey diagram, we first have to organize our data. We will use a Pandas DataFrame to prepare and store the data.
We will split the process into several steps:
Create the DataFrame from the confusion matrix, using the previously defined target_names as row and column names.
# Create dataframe
df = pd.DataFrame(cm, columns=target_names, index=target_names)
df
| | Stays | Exits |
|---|---|---|
| Stays | 1979 | 410 |
| Exits | 198 | 413 |
A Sankey diagram requires three data columns: one for the "From" column (source nodes), one for the "To" column (target nodes), and one for the values (flow quantity) corresponding to each pairing (link).
We need to transform this base dataframe to the following dataframe:
actual predicted samples
0 ACTUAL Stays PREDICTED Stays 1979
1 ACTUAL Stays PREDICTED Exits 410
2 ACTUAL Exits PREDICTED Stays 198
3 ACTUAL Exits PREDICTED Exits 413
For our Sankey diagram:
- actual represents Sankey source nodes
- predicted represents Sankey target nodes
- samples represents flow quantities for Sankey links.

Later we will add more columns to improve the interpretability of our Sankey confusion matrix.
Let's name the row axis ACTUAL and the column axis PREDICTED.
# Axes naming
df = df.rename_axis(index='ACTUAL', columns='PREDICTED')
df
PREDICTED | Stays | Exits |
---|---|---|
ACTUAL | ||
Stays | 1979 | 410 |
Exits | 198 | 413 |
Let's prepend the axis names to the row and column labels.
[f'ACTUAL {s}' for s in target_names]
['ACTUAL Stays', 'ACTUAL Exits']
# Set new labels for rows and columns
df = df.set_axis([f'ACTUAL {s}' for s in target_names], axis=0)
df = df.set_axis([f'PREDICTED {s}' for s in target_names], axis=1)
print(df)
              PREDICTED Stays  PREDICTED Exits
ACTUAL Stays             1979              410
ACTUAL Exits              198              413
We can do the same in one line.
We will create a dataframe from the confusion matrix using the new labels for rows and columns.
df = pd.DataFrame(cm, columns=[f'PREDICTED {s}' for s in target_names], index=[f'ACTUAL {s}' for s in target_names])
print(df)
              PREDICTED Stays  PREDICTED Exits
ACTUAL Stays             1979              410
ACTUAL Exits              198              413
IMPORTANT: Sankey diagrams only take integers for source and target values.
We will do this transformation a little later.
For now, let's prepare the data for it.
First we will create a list of node labels and then a dictionary of their indices.
# column labels --> Sankey target nodes
cl = df.columns.values.tolist()
cl
['PREDICTED Stays', 'PREDICTED Exits']
# row labels --> Sankey source nodes
rl = df.index.values.tolist()
rl
['ACTUAL Stays', 'ACTUAL Exits']
node_labels = rl + cl
node_labels
['ACTUAL Stays', 'ACTUAL Exits', 'PREDICTED Stays', 'PREDICTED Exits']
# Create dictionary with node labels indices
node_labels_inds = {label:ind for ind, label in enumerate(node_labels)}
node_labels_inds
{'ACTUAL Stays': 0, 'ACTUAL Exits': 1, 'PREDICTED Stays': 2, 'PREDICTED Exits': 3}
For the Sankey diagram we need to plot flows from source nodes to target nodes. The flows are the numbers of samples being correctly or incorrectly classified.
The new reshaped dataframe will have 2x2=4 rows, 4 combinations. Each row is one flow:
ACTUAL Stays → PREDICTED Stays = # of Stays correctly classified
ACTUAL Stays → PREDICTED Exits = # of Stays incorrectly classified
ACTUAL Exits → PREDICTED Stays = # of Exits incorrectly classified
ACTUAL Exits → PREDICTED Exits = # of Exits correctly classified
To accomplish this we will use Pandas functions:
# Reshape dataframe
df = df.stack().reset_index()
df.rename(columns={0:'samples', 'level_0':'actual', 'level_1':'predicted'}, inplace=True)
df
| | actual | predicted | samples |
|---|---|---|---|
| 0 | ACTUAL Stays | PREDICTED Stays | 1979 |
| 1 | ACTUAL Stays | PREDICTED Exits | 410 |
| 2 | ACTUAL Exits | PREDICTED Stays | 198 |
| 3 | ACTUAL Exits | PREDICTED Exits | 413 |
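The same reshaping works for any number of classes. The sketch below applies it to a 3x3 matrix with invented counts and placeholder class names (`a`, `b`, `c`), producing 3x3=9 rows, one per flow.

```python
import numpy as np
import pandas as pd

# Toy 3x3 confusion matrix with invented counts
toy_cm = np.array([[5, 1, 0],
                   [2, 6, 1],
                   [0, 3, 7]])
names = ['a', 'b', 'c']

wide = pd.DataFrame(toy_cm,
                    index=[f'ACTUAL {s}' for s in names],
                    columns=[f'PREDICTED {s}' for s in names])

# stack() moves the column labels into a second index level (one row per cell);
# reset_index() turns both index levels into ordinary columns
long = wide.stack().reset_index()
long.columns = ['actual', 'predicted', 'samples']
print(long)
```

Cells with a count of 0 are kept as rows, so every actual/predicted pair appears exactly once.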
# Normalized confusion matrix
cmn = np.around(cm / cm.sum(axis=1)[:, np.newaxis], 2)
print(cmn)
[[0.83 0.17]
 [0.32 0.68]]
# Flatten normalized confusion matrix and add as a new column
df['norm_samples'] = cmn.ravel()
df
| | actual | predicted | samples | norm_samples |
|---|---|---|---|---|
| 0 | ACTUAL Stays | PREDICTED Stays | 1979 | 0.83 |
| 1 | ACTUAL Stays | PREDICTED Exits | 410 | 0.17 |
| 2 | ACTUAL Exits | PREDICTED Stays | 198 | 0.32 |
| 3 | ACTUAL Exits | PREDICTED Exits | 413 | 0.68 |
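Adding the flattened normalized matrix as a column relies on `ravel()` and `stack()` visiting the cells in the same order (row by row). A quick check on a toy matrix with invented counts:

```python
import numpy as np
import pandas as pd

# Toy 2x2 matrix with invented counts and placeholder class names
toy_cm = np.array([[8, 2],
                   [3, 7]])
wide = pd.DataFrame(toy_cm,
                    index=['ACTUAL x', 'ACTUAL y'],
                    columns=['PREDICTED x', 'PREDICTED y'])

stacked_values = wide.stack().to_numpy()  # order produced by stack()
raveled_values = toy_cm.ravel()           # row-major order of the array

# Both orders agree, so a raveled matrix lines up with the stacked rows
print(np.array_equal(stacked_values, raveled_values))
```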
Columns color and link_hover_text
The link color is determined based on the classification result (correct or incorrect).
incorrect_red = "rgba(205, 92, 92, 0.8)"
correct_green = "rgba(144, 238, 144, 0.8)"
Create a helper function to add the columns color and link_hover_text, where link_hover_text holds the text displayed when hovering over the Sankey links.
# 'color' - link color based on classification result (correct or incorrect)
# 'link_hover_text' - text for hovering over connecting links of sankey diagram
def new_columns(row):
    # Class names without the ACTUAL / PREDICTED prefix
    source_1 = ''.join(row.actual.split()[1:])
    target_1 = ''.join(row.predicted.split()[1:])
    # Correct classification
    if source_1 == target_1:
        row['color'] = correct_green
        row['link_hover_text'] = f"{row.samples} ({row.norm_samples:.0%}) {source_1} samples correctly classified as {target_1}"
    # Incorrect classification
    else:
        row['color'] = incorrect_red
        row['link_hover_text'] = f"{row.samples} ({row.norm_samples:.0%}) {source_1} samples incorrectly classified as {target_1}"
    return row
Finalize the DataFrame.
# Apply helper function
df = df.apply(new_columns, axis=1)
df
| | actual | predicted | samples | norm_samples | color | link_hover_text |
|---|---|---|---|---|---|---|
| 0 | ACTUAL Stays | PREDICTED Stays | 1979 | 0.83 | rgba(144, 238, 144, 0.8) | 1979 (83%) Stays samples correctly classified as Stays |
| 1 | ACTUAL Stays | PREDICTED Exits | 410 | 0.17 | rgba(205, 92, 92, 0.8) | 410 (17%) Stays samples incorrectly classified as Exits |
| 2 | ACTUAL Exits | PREDICTED Stays | 198 | 0.32 | rgba(205, 92, 92, 0.8) | 198 (32%) Exits samples incorrectly classified as Stays |
| 3 | ACTUAL Exits | PREDICTED Exits | 413 | 0.68 | rgba(144, 238, 144, 0.8) | 413 (68%) Exits samples correctly classified as Exits |
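The row-wise helper is easy to read, but the color column could also be computed without `apply`. The following is only a sketch of a vectorized alternative on a toy frame, not the approach used in this notebook. It strips the ACTUAL / PREDICTED prefix with `str.split` and picks the color with `np.where`.

```python
import numpy as np
import pandas as pd

toy = pd.DataFrame({
    'actual': ['ACTUAL Stays', 'ACTUAL Stays', 'ACTUAL Exits', 'ACTUAL Exits'],
    'predicted': ['PREDICTED Stays', 'PREDICTED Exits',
                  'PREDICTED Stays', 'PREDICTED Exits'],
})

# Strip the first word (ACTUAL / PREDICTED) to recover the class name
actual_cls = toy['actual'].str.split(n=1).str[1]
predicted_cls = toy['predicted'].str.split(n=1).str[1]

# A link is "correct" when the actual and predicted class names match
is_correct = actual_cls == predicted_cls
toy['color'] = np.where(is_correct,
                        'rgba(144, 238, 144, 0.8)',  # green: correct
                        'rgba(205, 92, 92, 0.8)')    # red: incorrect
print(toy)
```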
Map the node label columns (actual, predicted) to integers, as required by Sankey.
node_labels_inds
{'ACTUAL Stays': 0, 'ACTUAL Exits': 1, 'PREDICTED Stays': 2, 'PREDICTED Exits': 3}
# using replace for multiple columns
df = df.replace({'actual':node_labels_inds, 'predicted':node_labels_inds})
df
| | actual | predicted | samples | norm_samples | color | link_hover_text |
|---|---|---|---|---|---|---|
| 0 | 0 | 2 | 1979 | 0.83 | rgba(144, 238, 144, 0.8) | 1979 (83%) Stays samples correctly classified as Stays |
| 1 | 0 | 3 | 410 | 0.17 | rgba(205, 92, 92, 0.8) | 410 (17%) Stays samples incorrectly classified as Exits |
| 2 | 1 | 2 | 198 | 0.32 | rgba(205, 92, 92, 0.8) | 198 (32%) Exits samples incorrectly classified as Stays |
| 3 | 1 | 3 | 413 | 0.68 | rgba(144, 238, 144, 0.8) | 413 (68%) Exits samples correctly classified as Exits |
# Alternative: using assign + apply + lambda
# df.assign(actual = df.actual.apply(lambda x: node_labels_inds[x]),
#           predicted = df.predicted.apply(lambda x: node_labels_inds[x]))
# Alternative: using assign + map
# df.assign(actual = df.actual.map(node_labels_inds),
#           predicted = df.predicted.map(node_labels_inds))
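As a runnable version of the assign + map alternative, here is a small sketch on a toy frame. The names `node_index`, `toy`, and `mapped` are used only in this example.

```python
import pandas as pd

node_index = {'ACTUAL Stays': 0, 'ACTUAL Exits': 1,
              'PREDICTED Stays': 2, 'PREDICTED Exits': 3}

toy = pd.DataFrame({
    'actual': ['ACTUAL Stays', 'ACTUAL Stays', 'ACTUAL Exits', 'ACTUAL Exits'],
    'predicted': ['PREDICTED Stays', 'PREDICTED Exits',
                  'PREDICTED Stays', 'PREDICTED Exits'],
    'samples': [1979, 410, 198, 413],
})

# map() looks up each label in the dictionary;
# a label missing from the dictionary would become NaN
mapped = toy.assign(actual=toy['actual'].map(node_index),
                    predicted=toy['predicted'].map(node_index))
print(mapped)
```

Unlike `replace`, `map` touches only the column it is called on, so other text columns cannot be changed by accident.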
Prepare data for bold printing of some words in Plotly.
We want to print class names (2nd word in a string) in bold font.
We will use the HTML <b> tag for that.
node_labels
['ACTUAL Stays', 'ACTUAL Exits', 'PREDICTED Stays', 'PREDICTED Exits']
node_labels = [f'{ls[0]} <b>{ls[1]}</b>' for ls in [l.split() for l in node_labels]]
print(node_labels)
['ACTUAL <b>Stays</b>', 'ACTUAL <b>Exits</b>', 'PREDICTED <b>Stays</b>', 'PREDICTED <b>Exits</b>']
Printing class names in bold font.
df['link_hover_text'] = [f'{" ".join(ls[0:2])} <b>{ls[2]}</b> {" ".join(ls[3:-1])} <b>{ls[-1]}</b>' for ls in [l.split() for l in df['link_hover_text']]]
df
| | actual | predicted | samples | norm_samples | color | link_hover_text |
|---|---|---|---|---|---|---|
| 0 | 0 | 2 | 1979 | 0.83 | rgba(144, 238, 144, 0.8) | 1979 (83%) <b>Stays</b> samples correctly classified as <b>Stays</b> |
| 1 | 0 | 3 | 410 | 0.17 | rgba(205, 92, 92, 0.8) | 410 (17%) <b>Stays</b> samples incorrectly classified as <b>Exits</b> |
| 2 | 1 | 2 | 198 | 0.32 | rgba(205, 92, 92, 0.8) | 198 (32%) <b>Exits</b> samples incorrectly classified as <b>Stays</b> |
| 3 | 1 | 3 | 413 | 0.68 | rgba(144, 238, 144, 0.8) | 413 (68%) <b>Exits</b> samples correctly classified as <b>Exits</b> |
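The list comprehensions above pick words by position, which assumes that every class name is a single word. If class names may contain spaces, one option is to split only once. The helper below is a hypothetical sketch (`bold_class_name` is not part of metrics_utilities.py) that bolds everything after the first word of a node label.

```python
def bold_class_name(label):
    """Wrap everything after the first word in <b> tags.

    Hypothetical helper for illustration; not part of metrics_utilities.py.
    """
    # maxsplit=1 keeps multi-word class names together
    prefix, class_name = label.split(maxsplit=1)
    return f'{prefix} <b>{class_name}</b>'


labels = ['ACTUAL Stays', 'PREDICTED Exits', 'ACTUAL No diabetes']
print([bold_class_name(label) for label in labels])
```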
fig = go.Figure(data=[go.Sankey(
    node = dict(
        pad = 30,
        thickness = 20,
        line = dict(color = "gray", width = 1.0),
        label = node_labels,
        hovertemplate = "%{label} has total %{value:d} samples<extra></extra>"
    ),
    link = dict(
        source = df.actual,
        target = df.predicted,
        value = df.samples,
        color = df.color,
        customdata = df['link_hover_text'],
        hovertemplate = "%{customdata}<extra></extra>"
    ))])
title = 'Decision Tree'
fig.update_layout(
    # hovermode = 'x',
    title = {
        'text': title,
        'x': 0.5,
    },
    # paper_bgcolor = '#51504f',
    font_size = 15,
    # font_color = 'white',
    width = 600,
    height = 500
)
To interpret our Sankey confusion matrix, with Stays representing the negative class and Exits representing the positive class, take note of the following key elements:
Let's now collect all these together into a function.
def plot_cm_sankey(model_name, y_test, y_pred, target_names=None):
    """ Plot confusion matrix with Sankey diagram
    Args:
        model_name: name of the model
        y_test: test target variable
        y_pred: prediction
        target_names: list of class names
    Returns:
        Plot Sankey diagram of confusion matrix
    """
    # Calculate confusion matrix
    cm = confusion_matrix(y_test, y_pred)
    # If class labels not passed, create dummy class labels
    if target_names is None:
        target_names = []
    if not len(target_names):
        target_names = [f'class-{i+1}' for i in range(len(cm))]

    # Prepare dataframe with parameters for Sankey
    def prepare_df_for_sankey(cm, target_names):
        # Create a dataframe
        df = pd.DataFrame(cm, columns=[f'PREDICTED {s}' for s in target_names], index=[f'ACTUAL {s}' for s in target_names])
        # Create list of node labels
        # target nodes = column labels (PREDICTED ...)
        cl = df.columns.values.tolist()
        # source nodes = row (index) labels (ACTUAL ...)
        rl = df.index.values.tolist()
        node_labels = rl + cl
        # Create dictionary with indices for node labels
        node_labels_inds = {label:ind for ind, label in enumerate(node_labels)}
        # Stack label from column to row, output is Series
        # Reset index to get DataFrame and rename columns
        df = df.stack().reset_index()
        df.rename(columns={0:'samples', 'level_0':'actual', 'level_1':'predicted'}, inplace=True)
        # Resulting dataframe, e.g.:
        #          actual        predicted  samples
        # 0  ACTUAL Stays  PREDICTED Stays     1979
        # 1  ACTUAL Stays  PREDICTED Exits      410
        # 2  ACTUAL Exits  PREDICTED Stays      198
        # 3  ACTUAL Exits  PREDICTED Exits      413
        # Normalized confusion matrix
        cmn = np.around(cm / cm.sum(axis=1)[:, np.newaxis], 2)
        # Add a column with normalized values of samples
        df['norm_samples'] = cmn.ravel()
        # Helper function to add new columns: color and link_hover_text
        # 'color' - link color based on classification result (correct or incorrect)
        incorrect_red = "rgba(205, 92, 92, 0.8)"
        correct_green = "rgba(144, 238, 144, 0.8)"
        # 'link_hover_text' - text for hovering on connecting links of sankey diagram
        def new_columns(row):
            source_1 = ''.join(row.actual.split()[1:])
            target_1 = ''.join(row.predicted.split()[1:])
            # Correct classification
            if source_1 == target_1:
                row['color'] = correct_green
                row['link_hover_text'] = f"{row.samples} ({row.norm_samples:.0%}) {source_1} samples correctly classified as {target_1}"
            # Incorrect classification
            else:
                row['color'] = incorrect_red
                row['link_hover_text'] = f"{row.samples} ({row.norm_samples:.0%}) {source_1} samples incorrectly classified as {target_1}"
            return row
        # Apply "new_columns" function
        df = df.apply(new_columns, axis=1)
        # Sankey only takes integers for source and target values,
        # so we need to map node label columns (actual, predicted) to numbers
        # Using replace for multiple columns
        df = df.replace({'actual':node_labels_inds, 'predicted':node_labels_inds})
        return df, node_labels

    # Plotting confusion matrix as Sankey diagram
    # Get dataframe and node labels
    df, node_labels = prepare_df_for_sankey(cm, target_names)
    # Prepare for bold printing of some words in Plotly
    node_labels = [f'{ls[0]} <b>{ls[1]}</b>' for ls in [l.split() for l in node_labels]]
    df['link_hover_text'] = [f'{" ".join(ls[0:2])} <b>{ls[2]}</b> {" ".join(ls[3:-1])} <b>{ls[-1]}</b>' for ls in [l.split() for l in df['link_hover_text']]]
    fig = go.Figure(data=[go.Sankey(
        node = dict(
            pad = 50,
            thickness = 30,
            line = dict(color = "gray", width = 1.0),
            label = node_labels,
            hovertemplate = "%{label} has total %{value:d} samples<extra></extra>"
        ),
        link = dict(
            source = df.actual,
            target = df.predicted,
            value = df.samples,
            color = df.color,
            customdata = df['link_hover_text'],
            hovertemplate = "%{customdata}<extra></extra>"
        ))])
    margins = {'l': 25, 'r': 25, 't': 70, 'b': 25}
    fig.update_layout(
        title = {
            'text': f'<b>{model_name}</b>',
            'x': 0.5,
        },
        font_size = 15,
        width = 625,
        height = 500,
        # paper_bgcolor = '#d3d3d3',
        # paper_bgcolor = 'white',
        # plot_bgcolor = 'black',
        margin = margins,
    )
    return fig
Let's test the function
plot_cm_sankey('Decision Tree', y_test, pred_dt, target_names)
Copy the function to the module metrics_utilities.py and reload the kernel. After running all required cells above, run the following cell.
mu.plot_cm_sankey('Decision Tree', y_test, pred_dt, target_names)
Hover over the diagram to get more information about the confusion matrix.
Let's see how this works for a 3x3 confusion matrix.
We will use data from my project T2D-Predictions.
# Actual (True) labels
t2d_y_test = np.load('./data/t2d_y_test.npy')
# Prediction from random forest model
t2d_pred_rf = np.load('./data/t2d_pred_rf.npy')
# Classes
t2d_classes = ['no_diabetes', 'pre_diabetes', 'diabetes']
mu.plot_cm_sankey('Random Forest', t2d_y_test, t2d_pred_rf, t2d_classes)
Hover over the diagram to get more information about the confusion matrix.
NOTE 1:
Depending on the values in the confusion matrix, especially in the case of multi-class classification, Plotly might display the target nodes (PREDICTED labels) in a different order compared to the source nodes (ACTUAL labels).
However, rest assured that despite the potential rearrangement of nodes, all the connected links in the Sankey Diagram will accurately represent the correct width (values) as per the confusion matrix. The flow volumes will be appropriately maintained, enabling a precise representation of the classification outcomes between the actual and predicted labels.
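NOTE 2 - for Jupyter Lab users:
- To allow Plotly to display figures correctly, set the default renderer explicitly to iframe:
import plotly.io as pio
pio.renderers.default = 'iframe'
- If multiple notebooks are using iframe, set a different html_directory for each notebook. In this notebook we named it iframe_figures_n1:
iframe_renderer = pio.renderers['iframe']
iframe_renderer.html_directory='iframe_figures_n1'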