This demo shows how to create an interactive visualisation of clusters of linked records.
The first few steps re-use the model trained in the deduplication quickstart example.
import altair as alt
import pandas as pd
from utility_functions.demo_utils import get_spark
spark = get_spark() # See utility_functions/demo_utils.py for how to set up Spark
df = spark.read.parquet("data/fake_1000.parquet")
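get_spark() just returns a configured SparkSession. A minimal sketch of what the helper might contain (the real implementation is in utility_functions/demo_utils.py; the app name and config value below are illustrative assumptions, not taken from the repo):
from pyspark.sql import SparkSession

def get_spark():
    # Illustrative only: the real helper may configure Spark differently
    return (
        SparkSession.builder.appName("splink_cluster_demo")
        .config("spark.sql.shuffle.partitions", "8")  # small demo dataset
        .getOrCreate()
    )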
from splink import Splink
settings = {
"link_type": "dedupe_only",
"blocking_rules": ["l.surname = r.surname",
"l.first_name = r.first_name",
"l.dob = r.dob",
"l.email = r.email",
],
"comparison_columns": [
{
"col_name": "first_name",
"num_levels": 3,
"term_frequency_adjustments": True,
"m_probabilities": [
0.3941434323787689,
0.14060422778129578,
0.4652523398399353,
],
"u_probabilities": [
0.9941955208778381,
0.0028420439921319485,
0.002962463302537799,
],
},
{
"col_name": "surname",
"num_levels": 3,
"term_frequency_adjustments": True,
"m_probabilities": [
0.3971782326698303,
0.11397389322519302,
0.48884785175323486,
],
"u_probabilities": [
0.9930331110954285,
0.00222682929597795,
0.004740049596875906,
],
},
{
"col_name": "dob",
"m_probabilities": [0.38818904757499695, 0.6118109226226807],
"u_probabilities": [0.9997655749320984, 0.00023440067889168859],
},
{
"col_name": "city",
"case_expression": "case\n when city_l is null or city_r is null then -1\n when city_l = city_r then 1\n else 0 end as gamma_city",
"m_probabilities": [0.29216697812080383, 0.7078329920768738],
"u_probabilities": [0.9105007648468018, 0.08949924260377884],
},
{
"col_name": "email",
"m_probabilities": [0.32461094856262207, 0.6753890514373779],
"u_probabilities": [0.999818742275238, 0.00018127892690245062],
},
],
"additional_columns_to_retain": ["group"],
"proportion_of_matches": 0.005672720726579428,
}
linker = Splink(settings, df, spark)
# Score pairwise record comparisons using the pre-trained m and u probabilities
df_e = linker.manually_apply_fellegi_sunter_weights()
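manually_apply_fellegi_sunter_weights scores each pairwise comparison using the m and u probabilities supplied in the settings, rather than estimating them. Under the Fellegi-Sunter model each comparison level contributes a Bayes factor of m/u; a quick illustrative calculation for the strongest first_name level, taking the final entries of the m and u lists above as the top agreement level (an assumption about level ordering):
# Illustrative only: Bayes factor for the top first_name comparison level
m = 0.4652523398399353
u = 0.002962463302537799
print(f"Bayes factor m/u: {m / u:.0f}")  # ~157: strong evidence for a match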
# Cluster at a match probability threshold of 0.75
from splink.cluster import clusters_at_thresholds
nodes_with_clusters = clusters_at_thresholds(df, df_e, {'cluster_low':0.75}, linker.model)
nodes_with_clusters.toPandas().sort_values("cluster_low").head()
| | cluster_low | unique_id | first_name | surname | dob | city | email | group |
|---|---|---|---|---|---|---|---|---|
| 0 | 0 | 2 | Julia | Taylor | 2016-01-27 | London | hannah88@powers.com | 0 |
| 387 | 0 | 3 | Julia | Taylor | 2015-10-29 | None | hannah88opowersc@m | 0 |
| 644 | 0 | 0 | Julia | None | 2015-10-29 | London | hannah88@powers.com | 0 |
| 645 | 0 | 1 | Julia | Taylor | 2015-07-31 | London | hannah88@powers.com | 0 |
| 388 | 4 | 7 | Noah | Watson | 2008-02-05 | tolon | matthew78@ballard-mcdonald.net | 1 |
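clusters_at_thresholds takes a dict mapping an output column name to a match probability threshold. Assuming the dict accepts several entries (the plural function name suggests so, but this is untested here), clusterings at multiple thresholds can be produced in one call:
# Hypothetical: cluster at two thresholds in a single pass; each dict
# entry should become its own column of cluster ids in the output
nodes_two_thresholds = clusters_at_thresholds(
    df, df_e, {"cluster_low": 0.75, "cluster_high": 0.95}, linker.model
)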
# Get a few of the largest clusters
nodes_with_clusters.createOrReplaceTempView("nodes_with_clusters")
sql = """
select count(*) as count, cluster_low
from nodes_with_clusters
group by cluster_low
order by count(*) desc
limit 10
"""
largest_clusters = spark.sql(sql).toPandas()  # the SQL already limits to 10 rows
display(largest_clusters.head(3))
cluster_ids = list(largest_clusters["cluster_low"])
| | count | cluster_low |
|---|---|---|
| 0 | 16 | 394 |
| 1 | 10 | 804 |
| 2 | 10 | 654 |
cluster_ids
[394, 804, 654, 279, 517, 301, 664, 105, 581, 194]
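For reference, the same top-10 list can be computed without SQL; a pure-pandas equivalent of the query above (an illustrative alternative, not part of the original demo):
# Pandas equivalent of the SQL above
largest_clusters_pd = (
    nodes_with_clusters.toPandas()
    .groupby("cluster_low")
    .size()
    .sort_values(ascending=False)
    .head(10)
)
cluster_ids_alt = list(largest_clusters_pd.index)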
The visualisation needs a list of edges and a list of nodes. splink_cluster_studio
contains functions to create and format these tables, ready for input into the vis.
from splink_cluster_studio import (
get_edges_corresponding_to_clusters_from_spark,
get_nodes_corresponding_to_clusters_from_spark,
)
nodes_for_vis_pd = get_nodes_corresponding_to_clusters_from_spark(
nodes_with_clusters, "cluster_low", cluster_ids
)
edges_for_vis_pd = get_edges_corresponding_to_clusters_from_spark(
nodes_with_clusters, df_e, "cluster_low", cluster_ids
)
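Both helpers return pandas DataFrames (as the _pd suffix suggests), so a quick sanity check is straightforward; this inspection step is an addition, not part of the original demo:
# Sanity check: counts of extracted nodes and edges, and (assuming the
# cluster id column is retained in the output) the clusters represented
print(f"{len(nodes_for_vis_pd)} nodes, {len(edges_for_vis_pd)} edges")
print(nodes_for_vis_pd["cluster_low"].nunique(), "clusters represented")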
Optionally, we can compute graph metrics, which will be displayed in the vis.
If we have ground truth clusters (here, the group column), this information is displayed too.
from splink_cluster_studio import (
    compute_node_metrics,
    compute_edge_metrics,
    compute_cluster_metrics,
)

nodes_for_vis_pd = compute_node_metrics(
    nodes_for_vis_pd, edges_for_vis_pd, "cluster_low", ground_truth_cluster_colname="group"
)
edges_for_vis_pd = compute_edge_metrics(
    edges_for_vis_pd, "cluster_low", ground_truth_cluster_colname="group"
)
clusters_for_vis_pd = compute_cluster_metrics(edges_for_vis_pd, "cluster_low")
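Before rendering, the computed cluster-level metrics can be inspected directly (an added check; which metric columns appear depends on the splink_cluster_studio version):
# Peek at the per-cluster metrics that will be shown in the vis
display(clusters_for_vis_pd.head())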
The vis is rendered to a file, which you can load in your browser or display in an iframe in Jupyter.
from splink_cluster_studio import render_html_vis
splink_settings_dict = linker.model.current_settings_obj.settings_dict
render_html_vis(
    nodes_for_vis_pd,
    edges_for_vis_pd,
    splink_settings_dict,
    "interactive_clusters.html",
    "cluster_low",
    overwrite=True,
    df_cluster_metrics=clusters_for_vis_pd,
)
# Show the outputted html file in an iframe in Jupyter
from IPython.display import IFrame

IFrame(src="./interactive_clusters.html", width=1400, height=1200)
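If you are working outside a notebook, the rendered file can also be opened directly in the default browser using only the standard library:
# Alternative for non-notebook environments: open the vis in a browser
import pathlib
import webbrowser

webbrowser.open(pathlib.Path("interactive_clusters.html").resolve().as_uri())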