import multicelltypist
from datetime import date
import hisepy
import numpy as np
import os
import pandas as pd
import scanpy as sc
def read_adata_uuid(h5ad_uuid):
    # Retrieve the .h5ad file from HISE, caching it locally if not already present
    h5ad_path = '/home/jupyter/cache/{u}'.format(u = h5ad_uuid)
    if not os.path.isdir(h5ad_path):
        hise_res = hisepy.reader.cache_files([h5ad_uuid])
    h5ad_filename = os.listdir(h5ad_path)[0]
    h5ad_file = '{p}/{f}'.format(p = h5ad_path, f = h5ad_filename)
    adata = sc.read_h5ad(h5ad_file)
    return adata
def resample_anndata_min_max(adata, label_column, max_cells=None, min_cells=None, random_state = 3030):
    """
    Resamples an AnnData object based on cell labels, with the option to resample with
    replacement for classes below a specified threshold.

    Parameters:
    adata (AnnData): The AnnData object to be resampled.
    label_column (str): The column in adata.obs where the labels are stored.
    max_cells (int, optional): The maximum number of cells to keep per label. If None, no upper limit is applied.
    min_cells (int, optional): The minimum number of cells below which resampling with replacement occurs. If None, no lower limit is applied.
    random_state (int, default = 3030): An integer used to set the state of the numpy.random.Generator.

    Returns:
    AnnData: The resampled AnnData object.
    """
    labels = adata.obs[label_column].unique()
    subsets = []

    rng = np.random.default_rng(random_state)

    for label in labels:
        # Subset the cell metadata (adata.obs) for the current label
        subset = adata.obs[adata.obs[label_column] == label]

        # Resample with replacement if the number of cells is below min_cells and min_cells is defined
        if min_cells is not None and subset.shape[0] < min_cells:
            subset = subset.sample(min_cells, replace = True, random_state = rng)
        # Resample without replacement if the number of cells is greater than max_cells and max_cells is defined
        elif max_cells is not None and subset.shape[0] > max_cells:
            subset = subset.sample(max_cells, replace = False, random_state = rng)

        subsets.append(subset)

    # Concatenate all subsets and use their index to subset the AnnData object
    resampled_obs = pd.concat(subsets)
    resampled_adata = adata[resampled_obs.index]
    resampled_adata.obs_names_make_unique()

    return resampled_adata
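Before applying this to the full reference, here's a minimal sketch of how both resampling branches behave. The toy AnnData and its 'rare'/'common' labels are made up for illustration and are not part of the reference data:

# Toy illustration only (hypothetical data): a 2-cell 'rare' label is upsampled with
# replacement to min_cells, and a 10-cell 'common' label is downsampled to max_cells.
import anndata

toy = anndata.AnnData(
    X = np.random.default_rng(0).poisson(1, size = (12, 5)).astype(float),
    obs = pd.DataFrame({'label': ['rare'] * 2 + ['common'] * 10})
)
toy_resampled = resample_anndata_min_max(toy, 'label', max_cells = 5, min_cells = 4)
toy_resampled.obs['label'].value_counts()  # expected: common 5, rare 4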
label_column = 'AIFI_L3'
max_cell_number = 20000
h5ad_uuid = '6e8972a5-9463-4230-84b4-a20de055b9c3'
adata = read_adata_uuid(h5ad_uuid)
adata.shape
(1823666, 1261)
adata.obs[label_column].value_counts()
AIFI_L3
Core naive CD4 T cell      341521
Core CD14 monocyte         217576
CM CD4 T cell              161769
Core naive CD8 T cell      115126
GZMK- CD56dim NK cell      102908
                            ...
ASDC                          522
GZMK+ memory CD4 Treg         467
Activated memory B cell       433
CLP cell                      373
BaEoMaP cell                   78
Name: count, Length: 72, dtype: int64
adata_subset = resample_anndata_min_max(
    adata,
    label_column,
    max_cells = max_cell_number,
    random_state = 3030
)
adata_subset.shape
(636082, 1261)
adata_subset.obs[label_column].value_counts()
AIFI_L3
CM CD4 T cell                       20000
CM CD8 T cell                       20000
KLRF1+ GZMB+ CD27- EM CD8 T cell    20000
KLRF1- GZMB+ CD27- EM CD8 T cell    20000
Core naive CD4 T cell               20000
                                     ...
ASDC                                  522
GZMK+ memory CD4 Treg                 467
Activated memory B cell               433
CLP cell                              373
BaEoMaP cell                           78
Name: count, Length: 72, dtype: int64
adata_subset = adata_subset.raw.to_adata()
adata_subset.shape
(636082, 33538)
sc.pp.normalize_total(adata_subset, target_sum=1e4)
sc.pp.log1p(adata_subset)
WARNING: adata.X seems to be already log-transformed.
model_fs = multicelltypist.train(
    adata_subset,
    label_column,
    n_jobs = 60,
    max_iter = 10,
    multi_class = 'ovr',
    use_SGD = True
)
🍳 Preparing data before training
✂️ 4789 non-expressed genes are filtered out
🔬 Input data has 636082 cells and 28749 genes
⚖️ Scaling input data
🏋️ Training data using SGD logistic regression
⚠️ Warning: it may take a long time to train this dataset with 636082 cells and 28749 genes, try to downsample cells and/or restrict genes to a subset (e.g., hvgs)
✅ Model training done!
Detected genes: multicelltypist.train() drops genes with no detected expression in the training data (the 4,789 filtered genes reported above, leaving 33,538 - 4,789 = 28,749 features). We rebuild that list of retained genes here so we can map the classifier's coefficient columns back to gene names.
# Flag genes with zero total expression in the training subset and keep the rest
df = adata_subset.X.toarray()
flag = df.sum(axis = 0) == 0
gene = adata_subset.var_names[~flag]
Features with high absolute classifier coefficients for each cell type class

np.argpartition takes the coefficient scores for each class and moves the positions of the highest absolute coefficient scores to the end of an array of positions. We then select the top_n positions from the end of that array, which allows us to retrieve the genes with the highest absolute coefficients for each class. Combining these across classes gives a unique list of genes that are important for our model.
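As a quick illustration of what np.argpartition returns, here's a toy example with made-up coefficients (one class, five genes), not values from the model:

# Toy coefficients: the two largest |coef| values are in columns 1 and 3
coef = np.array([[0.1, -2.0, 0.5, 1.5, -0.2]])
order = np.argpartition(np.abs(coef), -2, axis = 1)  # positions of the two largest |coef| values land at the end
order[:, -2:]                                        # array([[3, 1]]): gene columns 3 and 1, in no particular order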
top_n = 200
gene_index = np.argpartition(
    np.abs(model_fs.classifier.coef_),
    -top_n,
    axis = 1
)
gene_index = gene_index[:, -top_n:]
gene_index = np.unique(gene_index)
print('Number of genes selected: {n}'.format(n = len(gene_index)))
Number of genes selected: 2507
selected_genes = gene[gene_index.tolist()]
selected_df = pd.DataFrame({'gene': selected_genes})
selected_df.head()
|   | gene     |
|---|----------|
| 0 | HES4     |
| 1 | ISG15    |
| 2 | TTLL10   |
| 3 | TNFRSF18 |
| 4 | TNFRSF4  |
adata = adata.raw.to_adata()
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
WARNING: adata.X seems to be already log-transformed.
adata = adata[:, adata.var_names.isin(selected_genes)]
adata.shape
(1823666, 2507)
model_fs = multicelltypist.train(
    adata,
    label_column,
    n_jobs = 60,
    max_iter = 100,
    multi_class = 'multinomial',
    check_expression = False
)
🍳 Preparing data before training
🔬 Input data has 1823666 cells and 2507 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
out_dir = 'output'
if not os.path.isdir(out_dir):
    os.makedirs(out_dir)
out_genes = 'output/ref_pbmc_clean_celltypist_top{n}_features_{l}_{d}.csv'.format(
    n = top_n,
    l = label_column,
    d = date.today()
)
selected_df.to_csv(out_genes)
out_model = 'output/ref_pbmc_clean_celltypist_model_{l}_{d}.pkl'.format(
    l = label_column,
    d = date.today()
)
model_fs.write(out_model)
Finally, we'll use hisepy.upload.upload_files() to send a copy of our output to HISE to use for downstream analysis steps.
study_space_uuid = '64097865-486d-43b3-8f94-74994e0a72e0'
title = 'PBMC Reference {l} CellTypist Model {d}'.format(
    l = label_column,
    d = date.today()
)
in_files = [h5ad_uuid]
in_files
['6e8972a5-9463-4230-84b4-a20de055b9c3']
out_files = [out_genes, out_model]
out_files
['output/ref_pbmc_clean_celltypist_top200_features_AIFI_L3_2024-03-11.csv', 'output/ref_pbmc_clean_celltypist_model_AIFI_L3_2024-03-11.pkl']
hisepy.upload.upload_files(
    files = out_files,
    study_space_id = study_space_uuid,
    title = title,
    input_file_ids = in_files
)
output/ref_pbmc_clean_celltypist_top200_features_AIFI_L3_2024-03-11.csv
output/ref_pbmc_clean_celltypist_model_AIFI_L3_2024-03-11.pkl
Cannot determine the current notebook.
1) /home/jupyter/scRNA-Reference-IH-A/06-Modeling/31-Python_celltypist_L3_model.ipynb
2) /home/jupyter/scRNA-Reference-IH-A/05-Assembly/28-Python_clean_T_cell_projections.ipynb
3) /home/jupyter/scRNA-Reference-IH-A/05-Assembly/27-Python_clean_Other_cell_projections.ipynb
Please select (1-3)
you are trying to upload file_ids... ['output/ref_pbmc_clean_celltypist_top200_features_AIFI_L3_2024-03-11.csv', 'output/ref_pbmc_clean_celltypist_model_AIFI_L3_2024-03-11.pkl']. Do you truly want to proceed?
{'trace_id': 'e3111ca7-cbfc-404d-b5c3-9c68fa313508', 'files': ['output/ref_pbmc_clean_celltypist_top200_features_AIFI_L3_2024-03-11.csv', 'output/ref_pbmc_clean_celltypist_model_AIFI_L3_2024-03-11.pkl']}
import session_info
session_info.show()
-----
anndata             0.10.3
hisepy              0.3.0
multicelltypist     1.6.2
numpy               1.25.2
pandas              2.1.4
scanpy              1.9.6
session_info        1.0.0
-----
PIL 10.0.1 anyio NA arrow 1.3.0 asttokens NA attr 23.2.0 attrs 23.2.0 babel 2.14.0 beatrix_jupyterlab NA brotli NA cachetools 5.3.1 certifi 2024.02.02 cffi 1.16.0 charset_normalizer 3.3.2 cloudpickle 2.2.1 colorama 0.4.6 comm 0.1.4 cryptography 41.0.7 cycler 0.10.0 cython_runtime NA dateutil 2.8.2 db_dtypes 1.1.1 debugpy 1.8.0 decorator 5.1.1 defusedxml 0.7.1 deprecated 1.2.14 exceptiongroup 1.2.0 executing 2.0.1 fastjsonschema NA fqdn NA google NA greenlet 2.0.2 grpc 1.58.0 grpc_status NA h5py 3.10.0 idna 3.6 igraph 0.10.8 importlib_metadata NA ipykernel 6.28.0 ipython_genutils 0.2.0 ipywidgets 8.1.1 isoduration NA jedi 0.19.1 jinja2 3.1.2 joblib 1.3.2 json5 NA jsonpointer 2.4 jsonschema 4.20.0 jsonschema_specifications NA jupyter_events 0.9.0 jupyter_server 2.12.1 jupyterlab_server 2.25.2 jwt 2.8.0 kiwisolver 1.4.5 leidenalg 0.10.1 llvmlite 0.41.0 lz4 4.3.2 markupsafe 2.1.3 matplotlib 3.8.0 matplotlib_inline 0.1.6 mpl_toolkits NA mpmath 1.3.0 natsort 8.4.0 nbformat 5.9.2 numba 0.58.0 opentelemetry NA overrides NA packaging 23.2 parso 0.8.3 pexpect 4.8.0 pickleshare 0.7.5 pkg_resources NA platformdirs 4.1.0 plotly 5.18.0 prettytable 3.9.0 prometheus_client NA prompt_toolkit 3.0.42 proto NA psutil NA ptyprocess 0.7.0 pure_eval 0.2.2 pyarrow 13.0.0 pydev_ipython NA pydevconsole NA pydevd 2.9.5 pydevd_file_utils NA pydevd_plugins NA pydevd_tracing NA pygments 2.17.2 pynvml NA pyparsing 3.1.1 pyreadr 0.5.0 pythonjsonlogger NA pytz 2023.3.post1 referencing NA requests 2.31.0 rfc3339_validator 0.1.4 rfc3986_validator 0.1.1 rpds NA scipy 1.11.4 send2trash NA shapely 1.8.5.post1 six 1.16.0 sklearn 1.3.2 sniffio 1.3.0 socks 1.7.1 sql NA sqlalchemy 2.0.21 sqlparse 0.4.4 stack_data 0.6.2 sympy 1.12 termcolor NA texttable 1.7.0 threadpoolctl 3.2.0 torch 2.1.2+cu121 torchgen NA tornado 6.3.3 tqdm 4.66.1 traitlets 5.9.0 typing_extensions NA uri_template NA urllib3 1.26.18 wcwidth 0.2.12 webcolors 1.13 websocket 1.7.0 wrapt 1.15.0 xarray 2023.12.0 yaml 6.0.1 zipp NA zmq 25.1.2 zoneinfo NA zstandard 0.22.0
-----
IPython             8.19.0
jupyter_client      8.6.0
jupyter_core        5.6.1
jupyterlab          4.1.2
notebook            6.5.4
-----
Python 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0]
Linux-5.15.0-1053-gcp-x86_64-with-glibc2.31
-----
Session information updated at 2024-03-11 01:16