Creating a new kernel

This tutorial shows how to create your own kernel class which computes a cell-cell transition matrix. For some example kernel classes, check out CellRank's VelocityKernel or ConnectivityKernel. Contributing a new kernel class is the preferred way of interfacing to an external method that computes cell-cell transition probabilities.

Import packages & data

In [1]:
import sys

if "google.colab" in sys.modules:
    !pip install -q git+https://github.com/theislab/[email protected]
In [2]:
from typing import Any
from copy import copy
from anndata import AnnData

import cellrank as cr
import numpy as np
import scipy.sparse as sp

Import an example dataset.

In [3]:
adata = cr.datasets.pancreas()
adata
Out[3]:
AnnData object with n_obs × n_vars = 2531 × 27998
    obs: 'day', 'proliferation', 'G2M_score', 'S_score', 'phase', 'clusters_coarse', 'clusters', 'clusters_fine', 'louvain_Alpha', 'louvain_Beta', 'palantir_pseudotime'
    var: 'highly_variable_genes'
    uns: 'clusters_colors', 'clusters_fine_colors', 'day_colors', 'louvain_Alpha_colors', 'louvain_Beta_colors', 'neighbors', 'pca'
    obsm: 'X_pca', 'X_umap'
    layers: 'spliced', 'unspliced'
    obsp: 'connectivities', 'distances'

Minimal kernel

In order to create your own kernel class, you just need to do the following three things:

  • subclass from cellrank.tl.kernels.Kernel.
  • implement a .compute_transition_matrix method. This should be the core of your method: the algorithm which takes data and computes a cell-cell transition matrix from it. To save the matrix in the object, use the ._compute_transition_matrix helper method (see below). Your .compute_transition_matrix method should return the kernel instance itself (i.e. self).
  • implement a .copy method, which returns a copy of the kernel.

The ._compute_transition_matrix helper method row-normalizes any matrix passed to it (all elements must be non-negative) and optionally computes the condition number (can be costly and only works on dense matrices).

Below you can see a minimal implementation of a kernel where the transition matrix is simply a diagonal matrix.

In [4]:
class MyKernel(cr.tl.kernels.Kernel):
    def compute_transition_matrix(self, some_parameter: float = 0.5) -> "MyKernel":
        transition_matrix = sp.diags(
            (some_parameter,) * len(self.adata), dtype=np.float64
        )
        self._compute_transition_matrix(transition_matrix, density_normalize=True)
        return self

    def copy(self) -> "MyKernel":
        return copy(self)
In [5]:
k = MyKernel(adata).compute_transition_matrix()
k.transition_matrix.toarray()
Out[5]:
array([[1., 0., 0., ..., 0., 0., 0.],
       [0., 1., 0., ..., 0., 0., 0.],
       [0., 0., 1., ..., 0., 0., 0.],
       ...,
       [0., 0., 0., ..., 1., 0., 0.],
       [0., 0., 0., ..., 0., 1., 0.],
       [0., 0., 0., ..., 0., 0., 1.]])
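
The identity matrix above follows from row normalization: every row of the diagonal matrix holds a single entry of 0.5, which is divided by the row sum 0.5. The following is a minimal sketch of that normalization step using SciPy. Note that row_normalize is a hypothetical helper written for illustration; it is not CellRank's ._compute_transition_matrix, which may handle additional cases.

```python
import numpy as np
import scipy.sparse as sp


def row_normalize(matrix: sp.spmatrix) -> sp.csr_matrix:
    """Scale each row of a non-negative sparse matrix so that it sums to 1.

    Hypothetical helper for illustration only, not part of CellRank.
    """
    matrix = sp.csr_matrix(matrix, dtype=np.float64)
    if matrix.nnz and matrix.data.min() < 0:
        raise ValueError("All elements must be non-negative.")
    row_sums = np.asarray(matrix.sum(axis=1)).ravel()
    if np.any(row_sums == 0):
        raise ValueError("Every row needs at least one positive entry.")
    # Left-multiplying by diag(1 / row_sums) divides row i by its sum.
    return (sp.diags(1.0 / row_sums) @ matrix).tocsr()


# Same construction as in the kernel above, but for only four cells.
raw = sp.diags((0.5,) * 4, dtype=np.float64)
normalized = row_normalize(raw)
print(normalized.toarray())
```

Any positive value of some_parameter therefore gives the same result in this toy kernel; a real kernel would fill in off-diagonal entries that depend on the data.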

Reading from AnnData

CellRank is part of the scanpy ecosystem and relies on AnnData objects to store and manipulate single-cell data. The example below shows how to read data from an AnnData object through a kernel class using the ._read_from_adata method, which is invoked when initializing the class.

In [6]:
class MyKernel(cr.tl.kernels.Kernel):
    def __init__(
        self, adata: AnnData, obs_key: str = "palantir_pseudotime", **kwargs: Any
    ):
        super().__init__(adata=adata, obs_key=obs_key, **kwargs)

    def _read_from_adata(self, obs_key: str, **kwargs: Any) -> None:
        super()._read_from_adata(**kwargs)

        print(f"Reading `adata.obs[{obs_key!r}]`")
        self.pseudotime = self.adata.obs[obs_key].values

    def compute_transition_matrix(self, some_parameter: float = 0.5) -> "MyKernel":
        print("Accessing `.pseudotime`: ", self.pseudotime)
        transition_matrix = sp.diags(
            (some_parameter,) * len(self.adata), dtype=np.float64
        )

        self._compute_transition_matrix(transition_matrix)

        return self

    def copy(self) -> "MyKernel":
        return copy(self)

In the above example, we read a pseudotime from the .obs attribute of an AnnData object and store it in the kernel.

In [7]:
k = MyKernel(adata).compute_transition_matrix()
k
Reading `adata.obs['palantir_pseudotime']`
Accessing `.pseudotime`:  [0.81281052 0.81832897 0.48974318 ... 0.73317134 0.92208156 0.8219729 ]
Out[7]:
<MyKernel>
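
The example above stores the pseudotime but does not yet use it when computing the transition matrix. As a rough sketch of how such an attribute could feed into .compute_transition_matrix, the hypothetical function below keeps only those edges of a neighborhood graph that point towards cells with equal or higher pseudotime and then row-normalizes. The function name, the toy graph and the hard thresholding rule are assumptions made for illustration; none of them is taken from CellRank.

```python
import numpy as np
import scipy.sparse as sp


def pseudotime_biased_matrix(
    connectivities: sp.spmatrix, pseudotime: np.ndarray
) -> sp.csr_matrix:
    """Hypothetical example: drop graph edges that point backwards in pseudotime."""
    graph = sp.coo_matrix(connectivities, dtype=np.float64)
    # Keep edge i -> j only if cell j is not earlier than cell i.
    forward = pseudotime[graph.col] >= pseudotime[graph.row]
    biased = sp.csr_matrix(
        (graph.data[forward], (graph.row[forward], graph.col[forward])),
        shape=graph.shape,
    )
    # Cells without any forward edge get a self-loop so that no row is empty.
    row_sums = np.asarray(biased.sum(axis=1)).ravel()
    biased = biased + sp.diags((row_sums == 0).astype(np.float64))
    # Row-normalize so that every row is a probability distribution.
    row_sums = np.asarray(biased.sum(axis=1)).ravel()
    return (sp.diags(1.0 / row_sums) @ biased).tocsr()


# Toy data: three cells connected in a chain 0 - 1 - 2.
toy_connectivities = sp.csr_matrix(
    np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
)
toy_pseudotime = np.array([0.1, 0.5, 0.9])
toy_matrix = pseudotime_biased_matrix(toy_connectivities, toy_pseudotime)
print(toy_matrix.toarray())
```

Inside a kernel class, a matrix built this way from self.pseudotime would be handed to ._compute_transition_matrix instead of the diagonal placeholder.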

Caching values

Kernels can be combined using the elementwise operators + and *. However, this can lead to multiple evaluations of the same expression if it appears in several places in the combined expression. To avoid this, we cache the last computed transition matrix together with the parameters that were used to compute it.

As a part of this caching scheme, we provide a method ._reuse_cache(parameters: Dict[str, Any]) -> bool that returns True if a cached version for the parameters is available or False otherwise. It also updates the parameters, which are accessible through the .params attribute. We demonstrate this in the example below.

In [8]:
class MyKernel(cr.tl.kernels.Kernel):
    def __init__(
        self, adata: AnnData, obs_key: str = "palantir_pseudotime", **kwargs: Any
    ):
        super().__init__(adata=adata, obs_key=obs_key, **kwargs)

    def _read_from_adata(self, obs_key: str, **kwargs: Any) -> None:
        super()._read_from_adata(**kwargs)

        print(f"Reading `adata.obs[{obs_key!r}]`")
        self.pseudotime = self.adata.obs[obs_key].values

    def compute_transition_matrix(self, some_parameter: float = 0.5) -> "MyKernel":
        if self._reuse_cache({"some_parameter": some_parameter}):
            print("Using cached values for parameters:", self.params)
            return self

        transition_matrix = sp.diags(
            (some_parameter,) * len(self.adata), dtype=np.float64
        )

        self._compute_transition_matrix(transition_matrix, density_normalize=True)

        return self

    def copy(self) -> "MyKernel":
        return copy(self)
In [9]:
k = MyKernel(adata).compute_transition_matrix(some_parameter=0.1)
k.compute_transition_matrix(some_parameter=0.1)
print(k)
Reading `adata.obs['palantir_pseudotime']`
Using cached values for parameters: {'some_parameter': 0.1}
<MyKernel[some_parameter=0.1]>
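
The caching idea can also be illustrated without CellRank. The sketch below is a toy class, not CellRank's implementation: ToyKernel, its n_computations counter and the body of _reuse_cache are hypothetical and only mimic the behavior described above (return True when the parameters match the cached ones, otherwise store the new parameters and return False).

```python
from typing import Any, Dict, Optional

import numpy as np


class ToyKernel:
    """Hypothetical stand-in for a kernel that caches its last result."""

    def __init__(self, n_cells: int):
        self.n_cells = n_cells
        self.params: Dict[str, Any] = {}
        self.transition_matrix: Optional[np.ndarray] = None
        self.n_computations = 0  # counts how often the matrix is recomputed

    def _reuse_cache(self, params: Dict[str, Any]) -> bool:
        # A cached matrix is valid only if it was computed with equal parameters.
        if self.transition_matrix is not None and params == self.params:
            return True
        self.params = dict(params)
        return False

    def compute_transition_matrix(self, some_parameter: float = 0.5) -> "ToyKernel":
        if self._reuse_cache({"some_parameter": some_parameter}):
            return self
        matrix = np.eye(self.n_cells) * some_parameter
        self.transition_matrix = matrix / matrix.sum(axis=1, keepdims=True)
        self.n_computations += 1
        return self


k_toy = ToyKernel(n_cells=3)
k_toy.compute_transition_matrix(some_parameter=0.1)
k_toy.compute_transition_matrix(some_parameter=0.1)  # same parameters: cache hit
print(k_toy.n_computations)
k_toy.compute_transition_matrix(some_parameter=0.2)  # new parameters: recompute
print(k_toy.n_computations)
```

Comparing the full parameter dictionary keeps the check simple; it assumes that all parameters are comparable with ==.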

Inverting a kernel

Kernels have a direction associated with them - intuitively, a kernel can be used to compute a transition matrix for the forward or the backward process. This is most intuitive for the velocity kernel, where the backwards direction corresponds to 'flipping' the arrows. In certain situations, this can help to find the initial states of a biological process. In CellRank, the direction of a kernel can be inverted using the ~ operator. Although this is a very niche functionality, we recommend overriding the __invert__ method. That's an in-place operation which does the following:

  • it changes the direction (i.e. the attribute .backward will become True if it was False and vice-versa).
  • it invalidates the current transition matrix and the parameters that were used to compute it.

The implementation really depends on the kernel class - which data it loads from AnnData and how it uses this data to compute cell-cell transition probabilities. In our case, we just need to change the .pseudotime attribute.

In [10]:
class MyKernel(cr.tl.kernels.Kernel):
    def __init__(
        self, adata: AnnData, obs_key: str = "palantir_pseudotime", **kwargs: Any
    ):
        super().__init__(adata=adata, obs_key=obs_key, **kwargs)

    def _read_from_adata(self, obs_key: str, **kwargs: Any) -> None:
        super()._read_from_adata(**kwargs)

        print(f"Reading `adata.obs[{obs_key!r}]`")
        self.pseudotime = self.adata.obs[obs_key].values

    def compute_transition_matrix(self, some_parameter: float = 0.5) -> "MyKernel":
        if self._reuse_cache({"some_parameter": some_parameter}):
            print("Using cached values for parameters:", self.params)
            return self

        transition_matrix = sp.diags(
            (some_parameter,) * len(self.adata), dtype=np.float64
        )

        self._compute_transition_matrix(transition_matrix, density_normalize=True)

        return self

    def __invert__(self) -> "MyKernel":
        super().__invert__()
        self.pseudotime = np.max(self.pseudotime) - self.pseudotime
        return self

    def copy(self) -> "MyKernel":
        return copy(self)
In [11]:
k = MyKernel(adata)
print("Is backward?", k.backward)
k.pseudotime
Reading `adata.obs['palantir_pseudotime']`
Is backward? False
Out[11]:
array([0.81281052, 0.81832897, 0.48974318, ..., 0.73317134, 0.92208156,
       0.8219729 ])
In [12]:
k_inv = ~k
print("Is inversion an in-place operation?", k_inv is k)
print("Is backward?", k.backward)
k.pseudotime
Is inversion an in-place operation? True
Is backward? True
Out[12]:
array([0.18718948, 0.18167103, 0.51025682, ..., 0.26682866, 0.07791844,
       0.1780271 ])
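
The values above are consistent with the rule max(pseudotime) - pseudotime for a pseudotime whose maximum is 1. A small NumPy sketch of that reversal, using made-up values rather than the pancreas data, shows that the ordering of the cells flips:

```python
import numpy as np

# Made-up pseudotime values for five cells (illustration only).
pseudotime = np.array([0.0, 0.25, 0.4, 0.9, 1.0])

# Same rule as in `__invert__` above: late cells become early and vice versa.
reversed_pseudotime = np.max(pseudotime) - pseudotime

print(reversed_pseudotime)
print(np.argsort(pseudotime), np.argsort(reversed_pseudotime))
```

Note that applying this rule twice restores the original values only if the minimum pseudotime is 0, as it is in this toy example.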

Conclusion

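Although CellRank's estimators can easily work with cell-cell transition matrices specified as numpy.ndarray or scipy.sparse.spmatrix, the kernel class offers various other benefits, such as those shown in this tutorial: reading data directly from AnnData, caching the computed transition matrix, combining kernels with the + and * operators, and inverting the direction of the process.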

If you're interested in contributing to CellRank, please check out our contributing guide. We welcome any contributions.