import multicelltypist
from datetime import date
import hisepy
import numpy as np
import os
import pandas as pd
import scanpy as sc
def read_adata_uuid(h5ad_uuid, cache_dir = '/home/jupyter/cache'):
    """
    Read an AnnData object stored in HISE, downloading it to the local cache first if needed.

    Parameters:
    h5ad_uuid (str): HISE file UUID of the .h5ad file to read.
    cache_dir (str, default = '/home/jupyter/cache'): Root of the local file cache.
        The file is expected under cache_dir/<h5ad_uuid>/. This must match the
        location hisepy.reader.cache_files() downloads to; override it only for
        non-default setups or testing.

    Returns:
    AnnData: The object returned by scanpy.read_h5ad().

    Raises:
    FileNotFoundError: If the cache directory for the UUID exists but is empty.
    """
    h5ad_dir = os.path.join(cache_dir, h5ad_uuid)
    if not os.path.isdir(h5ad_dir):
        # Download into the cache; the return value is not needed here.
        hisepy.reader.cache_files([h5ad_uuid])
    # os.listdir order is arbitrary; sort for a deterministic choice and prefer
    # .h5ad files in case other files share the directory.
    entries = sorted(os.listdir(h5ad_dir))
    candidates = [f for f in entries if f.endswith('.h5ad')] or entries
    if not candidates:
        raise FileNotFoundError(
            'No files found in cache directory {p}'.format(p = h5ad_dir)
        )
    h5ad_file = os.path.join(h5ad_dir, candidates[0])
    return sc.read_h5ad(h5ad_file)
def resample_anndata_min_max(adata, label_column, max_cells=None, min_cells=None, random_state = 3030):
    """
    Resamples an AnnData object per cell label, optionally up-sampling small
    classes (with replacement) and capping large classes (without replacement).

    For each distinct non-missing label in adata.obs[label_column], in order of
    first appearance:
      * if min_cells is not None and the label has fewer than min_cells cells,
        exactly min_cells cells are drawn with replacement (cells may repeat);
      * elif max_cells is not None and the label has more than max_cells cells,
        exactly max_cells cells are drawn without replacement;
      * otherwise every cell of the label is kept.
    Cells whose label is missing (NaN/NA) are not included in the result.

    Parameters:
    adata (AnnData): The AnnData object to be resampled.
    label_column (str): The column in adata.obs where the labels are stored.
    max_cells (int, optional): The maximum number of cells to keep per label. If None, no upper limit is applied.
    min_cells (int, optional): The minimum number of cells below which resampling with replacement occurs. If None, no lower limit is applied.
    random_state (int, default = 3030): Seed for the numpy.random.Generator that drives all sampling.

    Returns:
    AnnData: The resampled AnnData object, with cells grouped by label. Cells
    drawn more than once receive unique obs names via obs_names_make_unique().

    Raises:
    ValueError: If adata.obs[label_column] contains no non-missing labels.
    """
    label_values = adata.obs[label_column]
    # Missing labels would match no rows (NaN != NaN) and, with min_cells set,
    # would crash sample(); skip them explicitly.
    labels = label_values.dropna().unique()
    rng = np.random.default_rng(random_state)

    # Sample integer row positions instead of obs-name-indexed DataFrames: the
    # selection stays unambiguous even if obs_names are not unique, and no
    # filtered copy of the whole obs table is built per label. Series.sample
    # draws from the same RNG stream as DataFrame.sample for equal lengths.
    position_subsets = []
    for label in labels:
        is_label = (label_values == label).to_numpy(dtype = bool, na_value = False)
        positions = pd.Series(np.flatnonzero(is_label))
        n_cells = positions.shape[0]
        if min_cells is not None and n_cells < min_cells:
            positions = positions.sample(min_cells, replace = True, random_state = rng)
        elif max_cells is not None and n_cells > max_cells:
            positions = positions.sample(max_cells, replace = False, random_state = rng)
        position_subsets.append(positions.to_numpy())

    if not position_subsets:
        raise ValueError(
            'No non-missing labels found in adata.obs[{c!r}]'.format(c = label_column)
        )

    resampled_adata = adata[np.concatenate(position_subsets)]
    # Sampling with replacement repeats cells; give repeats distinct names.
    resampled_adata.obs_names_make_unique()
    return resampled_adata
# Label column to train on, and the per-label cap on cells used when
# down-sampling for the first (feature-selection) training pass.
label_column = 'AIFI_L2'
max_cell_number = 80000
# HISE file UUID of the reference .h5ad; read_adata_uuid() caches it locally
# if it is not already present.
h5ad_uuid = '6e8972a5-9463-4230-84b4-a20de055b9c3'
adata = read_adata_uuid(h5ad_uuid)
# Display (n_obs, n_vars).
adata.shape
(1823666, 1261)
# Number of cells per label before resampling.
adata.obs[label_column].value_counts()
AIFI_L2 Naive CD4 T cell 378071 Memory CD4 T cell 321788 CD14 monocyte 269328 Memory CD8 T cell 183096 CD56dim NK cell 133881 Naive CD8 T cell 121167 Naive B cell 86711 gdT 50587 MAIT 48084 Memory B cell 47886 CD16 monocyte 45920 Treg 39087 cDC2 14235 Intermediate monocyte 12671 Transitional B cell 12555 Effector B cell 11329 CD56bright NK cell 11055 Platelet 7903 pDC 7587 CD8aa 5737 Proliferating NK cell 2825 DN T cell 2349 Proliferating T cell 2320 Plasma cell 2151 Progenitor cell 1526 Erythrocyte 1508 cDC1 943 ILC 844 ASDC 522 Name: count, dtype: int64
# Cap every label at max_cell_number cells. min_cells is not given, so small
# labels are kept as-is (no up-sampling). The fixed seed makes it reproducible.
adata_subset = resample_anndata_min_max(
adata,
label_column,
max_cells = max_cell_number,
random_state = 3030
)
adata_subset.shape
(889624, 1261)
# Per-label cell counts after capping.
adata_subset.obs[label_column].value_counts()
AIFI_L2 CD14 monocyte 80000 CD56dim NK cell 80000 Naive B cell 80000 Memory CD8 T cell 80000 Memory CD4 T cell 80000 Naive CD4 T cell 80000 Naive CD8 T cell 80000 gdT 50587 MAIT 48084 Memory B cell 47886 CD16 monocyte 45920 Treg 39087 cDC2 14235 Intermediate monocyte 12671 Transitional B cell 12555 Effector B cell 11329 CD56bright NK cell 11055 Platelet 7903 pDC 7587 CD8aa 5737 Proliferating NK cell 2825 DN T cell 2349 Proliferating T cell 2320 Plasma cell 2151 Progenitor cell 1526 Erythrocyte 1508 cDC1 943 ILC 844 ASDC 522 Name: count, dtype: int64
# Switch to the full gene set stored in .raw. Per the shapes displayed in this
# notebook, the main matrix holds 1261 genes and .raw holds 33538.
adata_subset = adata_subset.raw.to_adata()
adata_subset.shape
(889624, 33538)
# Normalize each cell to 10,000 total counts, then apply log1p.
# NOTE(review): scanpy warns that X "seems to be already log-transformed".
# The warning may only be triggered by log1p metadata carried over in .uns by
# raw.to_adata(). If .raw actually holds log-normalized values, this cell
# transforms the data twice. Confirm what .raw contains.
sc.pp.normalize_total(adata_subset, target_sum=1e4)
sc.pp.log1p(adata_subset)
WARNING: adata.X seems to be already log-transformed.
# First, quick training pass on the down-sampled data. Its coefficients are
# used below to rank genes; model_fs is then replaced by the final model.
# Settings: few iterations (max_iter = 10), SGD solver, one-vs-rest.
model_fs = multicelltypist.train(
adata_subset,
label_column,
n_jobs = 60,
max_iter = 10,
multi_class = 'ovr',
use_SGD = True
)
🍳 Preparing data before training ✂️ 4486 non-expressed genes are filtered out 🔬 Input data has 889624 cells and 29052 genes ⚖️ Scaling input data 🏋️ Training data using SGD logistic regression ⚠️ Warning: it may take a long time to train this dataset with 889624 cells and 29052 genes, try to downsample cells and/or restrict genes to a subset (e.g., hvgs) ✅ Model training done!
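Detected genes: recover the names of the genes with non-zero expression, i.e. the genes kept by the training step above after it filtered out non-expressed genes, so that classifier coefficient columns can be mapped back to gene names.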
# Names of the genes with non-zero total expression in the training matrix.
# This mirrors the "non-expressed genes are filtered out" step in the training
# log (33538 - 4486 = 29052 genes), so positions in `gene` should line up with
# the columns of model_fs.classifier.coef_. That assumes the trainer filters on
# exact zero column sums and keeps gene order; TODO confirm.
# Sum the (possibly sparse) matrix directly rather than densifying it: a dense
# copy of an 889,624 x 33,538 matrix needs about 120 GB even at float32.
# np.asarray(...).ravel() turns the np.matrix returned by scipy.sparse into 1-D.
col_sums = np.asarray(adata_subset.X.sum(axis = 0)).ravel()
flag = col_sums == 0
gene = adata_subset.var_names[~flag]
Features with high absolute classifier coefficients for each cell type class
`np.argpartition` takes the absolute coefficient scores for each class and returns an array of positions. The array is partitioned so that the positions of the `top_n` largest scores occupy its last `top_n` entries, in no particular order.
Selecting the last `top_n` entries of each row therefore gives the genes with the highest absolute coefficients for each class.
We then combine these across classes to get a unique list of genes that are important for our model.
# For each class (row of coef_), keep the top_n gene columns by absolute
# coefficient, then pool them into one sorted, de-duplicated index array.
top_n = 200
gene_index = np.unique(
    np.argpartition(np.abs(model_fs.classifier.coef_), -top_n, axis = 1)[:, -top_n:]
)
print(f'Number of genes selected: {len(gene_index)}')
Number of genes selected: 1931
# Map the selected column indices back to gene names. gene_index indexes the
# expressed-gene list `gene`, not adata_subset.var_names. This assumes the
# classifier's coefficient columns follow that same gene order; TODO confirm.
selected_genes = gene[gene_index.tolist()]
selected_df = pd.DataFrame({'gene': selected_genes})
selected_df.head()
gene | |
---|---|
0 | HES4 |
1 | ISG15 |
2 | TTLL10 |
3 | TNFRSF18 |
4 | TNFRSF4 |
# Prepare the full (not down-sampled) data set the same way as the training
# subset: full gene set from .raw, normalize to 1e4 counts per cell, log1p.
# Normalization runs before gene subsetting, so per-cell totals use all genes.
# NOTE(review): the same "already log-transformed" warning appears here;
# confirm that .raw holds un-logged values.
adata = adata.raw.to_adata()
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
WARNING: adata.X seems to be already log-transformed.
# Keep only the selected feature genes (columns); all cells are kept.
adata = adata[:, adata.var_names.isin(selected_genes)]
adata.shape
(1823666, 1931)
# Final model: all cells, selected genes only, more iterations (max_iter = 100)
# and the default non-SGD solver ("Training data using logistic regression" in
# the log below). check_expression = False is presumably needed because, after
# gene subsetting, per-cell totals no longer match the normalization target the
# expression check looks for. Confirm against the multicelltypist docs.
model_fs = multicelltypist.train(
adata,
label_column,
n_jobs = 60,
max_iter = 100,
multi_class = 'ovr',
check_expression = False
)
🍳 Preparing data before training 🔬 Input data has 1823666 cells and 1931 genes ⚖️ Scaling input data 🏋️ Training data using logistic regression ✅ Model training done!
# Write the selected feature list (CSV) and the trained model (.pkl) to out_dir.
out_dir = 'output'
# exist_ok replaces the separate isdir() check (and its race).
os.makedirs(out_dir, exist_ok = True)
# Evaluate the date once so both file names carry the same date.
today = date.today()
# Build paths from out_dir rather than a hard-coded 'output/' prefix, so the two
# cannot drift apart if out_dir changes.
out_genes = os.path.join(
    out_dir,
    'ref_pbmc_clean_celltypist_top{n}_features_{l}_{d}.csv'.format(
        n = top_n,
        l = label_column,
        d = today
    )
)
selected_df.to_csv(out_genes)
out_model = os.path.join(
    out_dir,
    'ref_pbmc_clean_celltypist_model_{l}_{d}.pkl'.format(
        l = label_column,
        d = today
    )
)
model_fs.write(out_model)
Finally, we use `hisepy.upload.upload_files()` to send a copy of our outputs to HISE for use in downstream analysis steps.
# Upload metadata: destination study space, a dated title, and the UUID(s) of
# the input file(s), passed below as input_file_ids.
study_space_uuid = '64097865-486d-43b3-8f94-74994e0a72e0'
title = 'PBMC Reference {l} CellTypist Model {d}'.format(
l = label_column,
d = date.today()
)
in_files = [h5ad_uuid]
in_files
['6e8972a5-9463-4230-84b4-a20de055b9c3']
# Local paths written above, to be uploaded.
out_files = [out_genes, out_model]
out_files
['output/ref_pbmc_clean_celltypist_top200_features_AIFI_L2_2024-03-10.csv', 'output/ref_pbmc_clean_celltypist_model_AIFI_L2_2024-03-10.pkl']
# Upload the outputs to HISE. As the log below shows, this call is interactive:
# it may ask which notebook is current and ask for confirmation.
hisepy.upload.upload_files(
files = out_files,
study_space_id = study_space_uuid,
title = title,
input_file_ids = in_files
)
output/ref_pbmc_clean_celltypist_top200_features_AIFI_L2_2024-03-10.csv output/ref_pbmc_clean_celltypist_model_AIFI_L2_2024-03-10.pkl Cannot determine the current notebook. 1) /home/jupyter/scRNA-Reference-IH-A/06-Modeling/30-Python_celltypist_L2_model.ipynb 2) /home/jupyter/scRNA-Reference-IH-A/06-Modeling/31-Python_celltypist_L3_model.ipynb 3) /home/jupyter/scRNA-Reference-IH-A/06-Modeling/29-Python_celltypist_L1_model.ipynb Please select (1-3)
you are trying to upload file_ids... ['output/ref_pbmc_clean_celltypist_top200_features_AIFI_L2_2024-03-10.csv', 'output/ref_pbmc_clean_celltypist_model_AIFI_L2_2024-03-10.pkl']. Do you truly want to proceed?
{'trace_id': 'da03c2e2-b965-48ab-89a6-89b653f95b7d', 'files': ['output/ref_pbmc_clean_celltypist_top200_features_AIFI_L2_2024-03-10.csv', 'output/ref_pbmc_clean_celltypist_model_AIFI_L2_2024-03-10.pkl']}
# Record package and environment versions for reproducibility.
import session_info
session_info.show()
----- anndata 0.10.3 hisepy 0.3.0 multicelltypist 1.6.2 numpy 1.25.2 pandas 2.1.4 scanpy 1.9.6 session_info 1.0.0 -----
PIL 10.0.1 anyio NA arrow 1.3.0 asttokens NA attr 23.2.0 attrs 23.2.0 babel 2.14.0 beatrix_jupyterlab NA brotli NA cachetools 5.3.1 certifi 2024.02.02 cffi 1.16.0 charset_normalizer 3.3.2 cloudpickle 2.2.1 colorama 0.4.6 comm 0.1.4 cryptography 41.0.7 cycler 0.10.0 cython_runtime NA dateutil 2.8.2 db_dtypes 1.1.1 debugpy 1.8.0 decorator 5.1.1 defusedxml 0.7.1 deprecated 1.2.14 exceptiongroup 1.2.0 executing 2.0.1 fastjsonschema NA fqdn NA google NA greenlet 2.0.2 grpc 1.58.0 grpc_status NA h5py 3.10.0 idna 3.6 igraph 0.10.8 importlib_metadata NA ipykernel 6.28.0 ipython_genutils 0.2.0 ipywidgets 8.1.1 isoduration NA jedi 0.19.1 jinja2 3.1.2 joblib 1.3.2 json5 NA jsonpointer 2.4 jsonschema 4.20.0 jsonschema_specifications NA jupyter_events 0.9.0 jupyter_server 2.12.1 jupyterlab_server 2.25.2 jwt 2.8.0 kiwisolver 1.4.5 leidenalg 0.10.1 llvmlite 0.41.0 lz4 4.3.2 markupsafe 2.1.3 matplotlib 3.8.0 matplotlib_inline 0.1.6 mpl_toolkits NA mpmath 1.3.0 natsort 8.4.0 nbformat 5.9.2 numba 0.58.0 opentelemetry NA overrides NA packaging 23.2 parso 0.8.3 pexpect 4.8.0 pickleshare 0.7.5 pkg_resources NA platformdirs 4.1.0 plotly 5.18.0 prettytable 3.9.0 prometheus_client NA prompt_toolkit 3.0.42 proto NA psutil NA ptyprocess 0.7.0 pure_eval 0.2.2 pyarrow 13.0.0 pydev_ipython NA pydevconsole NA pydevd 2.9.5 pydevd_file_utils NA pydevd_plugins NA pydevd_tracing NA pygments 2.17.2 pynvml NA pyparsing 3.1.1 pyreadr 0.5.0 pythonjsonlogger NA pytz 2023.3.post1 referencing NA requests 2.31.0 rfc3339_validator 0.1.4 rfc3986_validator 0.1.1 rpds NA scipy 1.11.4 send2trash NA shapely 1.8.5.post1 six 1.16.0 sklearn 1.3.2 sniffio 1.3.0 socks 1.7.1 sql NA sqlalchemy 2.0.21 sqlparse 0.4.4 stack_data 0.6.2 sympy 1.12 termcolor NA texttable 1.7.0 threadpoolctl 3.2.0 torch 2.1.2+cu121 torchgen NA tornado 6.3.3 tqdm 4.66.1 traitlets 5.9.0 typing_extensions NA uri_template NA urllib3 1.26.18 wcwidth 0.2.12 webcolors 1.13 websocket 1.7.0 wrapt 1.15.0 xarray 2023.12.0 yaml 6.0.1 zipp NA zmq 25.1.2 
zoneinfo NA zstandard 0.22.0
----- IPython 8.19.0 jupyter_client 8.6.0 jupyter_core 5.6.1 jupyterlab 4.1.2 notebook 6.5.4 ----- Python 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0] Linux-5.15.0-1042-gcp-x86_64-with-glibc2.31 ----- Session information updated at 2024-03-10 03:22