Illustrates the concept of the GPCCA algorithm using a toy data example.
# import standard packages
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
import sys
# two extra imports here, pydtmc is for Markov chains, networkx helps us with graph visualization
import pydtmc
import networkx as nx
# import single-cell packages
import scanpy as sc
import scanpy.external as sce
import scvelo as scv
import cellrank as cr
# Set the logging verbosity of each single-cell package (scanpy, cellrank, scvelo).
for _package, _level in ((sc, 2), (cr, 2), (scv, 3)):
    _package.settings.verbosity = _level
If you want to exactly reproduce the results shown here, please make sure that your package versions match what is printed below.
# Print the installed versions of CellRank and the packages it relies on.
cr.logging.print_versions()
cellrank==1.5.0+g65f1562 scanpy==1.8.1 anndata==0.7.6 numpy==1.20.3 numba==0.54.0 scipy==1.7.1 pandas==1.3.3 pygpcca==1.0.2 scikit-learn==0.24.2 statsmodels==0.13.0rc0 python-igraph==0.9.1 scvelo==0.2.4 pygam==0.8.0 matplotlib==3.4.3 seaborn==0.11.2
# Print the running CellRank version together with the current date and time.
cr.logging.print_version_and_date()
Running CellRank 1.5.0+g65f1562, on 2021-10-26 16:00.
Define the paths to load data, cache results and write figure panels.
sys.path.insert(0, "../..") # this depends on the notebook depth and must be adapted per notebook
from paths import DATA_DIR, CACHE_DIR, FIG_DIR
We're only saving into a single directory here:
# Point FIG_DIR at the sub-directory that collects all panels of this notebook.
# NOTE(review): the `/` join presumably relies on FIG_DIR being a pathlib.Path - confirm in paths.py.
FIG_DIR = FIG_DIR / "suppl_fig_GPCCA"
Set up the paths to save figures.
# Use the same figure directory for scvelo, scanpy and cellrank.
_figdir = str(FIG_DIR)
for _settings in (scv.settings, sc.settings, cr.settings):
    _settings.figdir = _figdir
Set some plotting parameters.
# Apply the 'scvelo' figure style with the options below, then use an empty plot prefix.
_figure_params = dict(
    dpi_save=400,
    dpi=80,
    transparent=True,
    fontsize=20,
    color_map='viridis',
)
scv.settings.set_figure_params('scvelo', **_figure_params)
scv.settings.plot_prefix = ""
Set other global parameters
# should figures just be displayed or also saved?
# When True, the plotting cells below pass a `save` file name or call `plt.savefig`.
save_figure = True
If there are other global parameters for this analysis, put them here as well.
# Row-stochastic transition matrix of the 12-state toy Markov chain:
# entry (i, j) is the probability of moving from state i to state j.
p = np.array([
    # 0     1     2     3     4     5     6     7     8     9     10    11
    [0.00, 0.80, 0.20, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],  # 0
    [0.20, 0.00, 0.60, 0.20, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],  # 1
    [0.60, 0.20, 0.00, 0.20, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],  # 2
    [0.00, 0.05, 0.05, 0.00, 0.60, 0.30, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00],  # 3
    [0.00, 0.00, 0.00, 0.25, 0.00, 0.25, 0.40, 0.00, 0.00, 0.10, 0.00, 0.00],  # 4
    [0.00, 0.00, 0.00, 0.25, 0.25, 0.00, 0.10, 0.00, 0.00, 0.40, 0.00, 0.00],  # 5
    [0.00, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.70, 0.20, 0.00, 0.00, 0.00],  # 6
    [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.20, 0.00, 0.80, 0.00, 0.00, 0.00],  # 7
    [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.80, 0.20, 0.00, 0.00, 0.00, 0.00],  # 8
    [0.00, 0.00, 0.00, 0.00, 0.05, 0.05, 0.00, 0.00, 0.00, 0.00, 0.70, 0.20],  # 9
    [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.20, 0.00, 0.80],  # 10
    [0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.80, 0.20, 0.00],  # 11
])

# State labels "0" .. "11", one per column of the transition matrix.
states = np.arange(p.shape[1]).astype(str)
# Wrap the transition matrix and the state labels in a pydtmc Markov chain and print its summary.
mc = pydtmc.MarkovChain(p, states)
print(mc)
DISCRETE-TIME MARKOV CHAIN SIZE: 12 RANK: 12 CLASSES: 1 > RECURRENT: 1 > TRANSIENT: 0 ERGODIC: YES > APERIODIC: YES > IRREDUCIBLE: YES ABSORBING: NO REGULAR: NO REVERSIBLE: NO SYMMETRIC: NO
Plot the directed graph:
# Options for drawing the transition matrix as a directed graph (self-loops hidden,
# nodes labelled by state name and placed with networkx's spring layout).
fig_kwargs = dict(
    self_loops=False,
    self_loop_radius_frac=0.05,
    labels=states,
    layout=nx.layout.spring_layout,
    color_nodes=False,
    node_size=3000,
    edge_weight_scale=10,
    font_size=22,
    edge_width_limit=10,
)
if save_figure:
    fig_kwargs['save'] = 'graph.pdf'

cr.pl.graph(p, **fig_kwargs)
Creating graph Finish (0:00:00) Embedding graph using `'spring_layout'` layout Finish (0:00:00)
# Import the precomputed-kernel class and initialise it from the transition matrix.
from cellrank.tl.kernels import PrecomputedKernel

pk = PrecomputedKernel(p)

# Modify the underlying AnnData object: provide state names and a clustering
# that assigns three consecutive states to each of the four groups.
_toy_adata = pk._adata
_toy_adata.obs_names = states
annot = ['initial'] * 3 + ['interm'] * 3 + ['terminal_1'] * 3 + ['terminal_2'] * 3
_toy_adata.obs['clusters'] = pd.Series(data=annot, index=_toy_adata.obs_names, dtype='category')
WARNING: Creating empty `AnnData` object
Create a cellrank GPCCA object. This is a bit hacky, because we need a dummy adata object:
# Build the GPCCA estimator on top of the precomputed kernel `pk`.
g_cr = cr.tl.estimators.GPCCA(pk)
Visualise the transition matrix in a heatmap (plotted directly from the array `p`):
# Draw the raw transition matrix `p` as a heatmap on a dedicated 9x7 inch figure.
fig, ax = plt.subplots(figsize=(9, 7))
sns.heatmap(p, cmap='viridis', ax=ax)

if save_figure:
    plt.savefig(FIG_DIR / "transition_matrix.pdf")
plt.show()
Compute the Schur decomposition and plot the top eigenvalues:
# Schur decomposition with 12 components, then plot the spectrum.
g_cr.compute_schur(method='krylov', n_components=12, which='LR')

# Plot options; the 'save' entry is only added when figures should be written to disk.
fig_kwargs = {
    'real_only': True,
    'marker': 'o',
    's': 100,
    'figsize': (6, 6),
    **({'save': 'spectrum.pdf'} if save_figure else {}),
}
g_cr.plot_spectrum(**fig_kwargs)
Computing Schur decomposition When computing macrostates, choose a number of states NOT in `[7, 9, 11]` Adding `adata.uns['eigendecomposition_fwd']` `.schur_vectors` `.schur_matrix` `.eigendecomposition` Finish (0:00:00)
# Coarse-grain the chain into 4 macrostates, named after the 'clusters' annotation.
# `n_cells` is forwarded to the estimator as-is.
n_states, n_cells = 4, 3
g_cr.compute_macrostates(
    n_states=n_states,
    n_cells=n_cells,
    cluster_key='clusters',
)
Computing `4` macrostates Adding `.macrostates` `.macrostates_memberships` `.coarse_T` `.coarse_initial_distribution `.coarse_stationary_distribution` `.schur_vectors` `.schur_matrix` `.eigendecomposition` Finish (0:00:00)
Automatically determine the terminal states and show their names:
# Let the estimator pick the terminal states using the 'stability' method.
g_cr.compute_terminal_states(n_cells=n_cells, method='stability')

# Check that we got the right ones. The comparison is order-insensitive because the
# category order of `terminal_states` is an implementation detail of the estimator;
# a correct result listed in a different order must not fail the notebook.
terminal_names = list(g_cr.terminal_states.cat.categories)
assert sorted(terminal_names) == ['terminal_1', 'terminal_2'], (
    f"Unexpected terminal states: {terminal_names}"
)
print(f"Identified the following terminal states: {terminal_names}")
Adding `adata.obs['terminal_states']` `adata.obs['terminal_states_probs']` `.terminal_states` `.terminal_states_probabilities` `.terminal_states_memberships Finish` Identified the following terminal states: ['terminal_2', 'terminal_1']
Show the assignment of the individual states to the macrostates:
# Column order for the plots: follow the toy trajectory from initial to terminal states.
macro_names = ['initial', 'interm', 'terminal_1', 'terminal_2']

# Read the memberships through the public accessor instead of the private
# `_macrostates_memberships` attribute, consistent with how
# `absorption_probabilities` is accessed further down.
m_g = g_cr.macrostates_memberships[macro_names]

# One row per state, one column per macrostate, on a tall 5x10 inch figure.
_membership_ax = plt.figure(None, (5, 10)).gca()
sns.heatmap(
    m_g,
    robust=True,
    annot=False,
    ax=_membership_ax,
    fmt='.2f',
    cmap='viridis',
    xticklabels=m_g.names,
)
if save_figure:
    plt.savefig(FIG_DIR / "membership_matrix.pdf")
plt.show()
Plot the coarse-grained transition matrix:
# Coarse-grained transition matrix between macrostates, reordered to `macro_names`
# along both rows and columns before plotting.
T_coarse = g_cr.coarse_T
fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(T_coarse.loc[macro_names, macro_names], cmap='viridis', ax=ax)

if save_figure:
    plt.savefig(FIG_DIR / "t_coarse.pdf")
plt.show()
Compute absorption probabilities towards the two terminal states:
# Absorption probabilities of every state towards the two terminal states.
g_cr.compute_absorption_probabilities()
m_c = g_cr.absorption_probabilities[['terminal_1', 'terminal_2']]

# One row per state, one column per terminal state, on a tall 5x10 inch figure.
_fate_ax = plt.figure(None, (5, 10)).gca()
sns.heatmap(
    m_c,
    robust=True,
    annot=False,
    ax=_fate_ax,
    fmt='.2f',
    cmap='viridis',
    xticklabels=m_c.names,
)
if save_figure:
    plt.savefig(FIG_DIR / "fate_probs.pdf")
plt.show()
Computing absorption probabilities
0%| | 0/2 [00:00<?, ?/s]
Adding `adata.obsm['to_terminal_states']` `.absorption_probabilities` Finish (0:00:00)