import multicelltypist
from datetime import date
import hisepy
import numpy as np
import os
import pandas as pd
import scanpy as sc
cache_res = hisepy.download_from_project_store(
store_name = 'rds', # Reference Data Sets Project Store
file_name = 'AIFI-2024-03-11T02:09:16.856602896Z/Chromium_Human_Transcriptome_Probe_Set_v1.0_GRCh38-2020-A.probe_metadata.tsv', # File from 10x Genomics
)
probe_file = 'rds/Chromium_Human_Transcriptome_Probe_Set_v1.0_GRCh38-2020-A.probe_metadata.tsv'
flex_probes = pd.read_csv(probe_file, sep = '\t')
flex_genes = flex_probes['gene_name'].unique().tolist()
len(flex_genes)
18529
def read_adata_uuid(h5ad_uuid):
    """Cache (if needed) and read an .h5ad file from HISE using its file UUID."""
    h5ad_path = '/home/jupyter/cache/{u}'.format(u = h5ad_uuid)
    if not os.path.isdir(h5ad_path):
        hise_res = hisepy.reader.cache_files([h5ad_uuid])
    h5ad_filename = os.listdir(h5ad_path)[0]
    h5ad_file = '{p}/{f}'.format(p = h5ad_path, f = h5ad_filename)
    adata = sc.read_h5ad(h5ad_file)
    return adata
def resample_anndata_min_max(adata, label_column, max_cells=None, min_cells=None, random_state = 3030):
    """
    Resamples an AnnData object based on the cell labels, with the option to resample with
    replacement for classes below a specified threshold.

    Parameters:
    adata (AnnData): The AnnData object to be resampled.
    label_column (str): The column in adata.obs where the labels are stored.
    max_cells (int, optional): The maximum number of cells to keep per label. If None, no upper limit is applied.
    min_cells (int, optional): The minimum number of cells below which resampling with replacement occurs. If None, no lower limit is applied.
    random_state (int, default = 3030): An integer used to set the state of the numpy.random.Generator.

    Returns:
    AnnData: The resampled AnnData object.
    """
    labels = adata.obs[label_column].unique()
    subsets = []

    rng = np.random.default_rng(random_state)

    for label in labels:
        # Subset the cell metadata (adata.obs) for the current label
        subset = adata.obs[adata.obs[label_column] == label]

        # Resample with replacement if the number of cells is below min_cells and min_cells is defined
        if min_cells is not None and subset.shape[0] < min_cells:
            subset = subset.sample(min_cells, replace = True, random_state = rng)
        # Resample without replacement if the number of cells is greater than max_cells and max_cells is defined
        elif max_cells is not None and subset.shape[0] > max_cells:
            subset = subset.sample(max_cells, replace = False, random_state = rng)

        subsets.append(subset)

    # Concatenate all subsets and use their index to subset the original AnnData object
    resampled_obs = pd.concat(subsets)
    resampled_adata = adata[resampled_obs.index]
    resampled_adata.obs_names_make_unique()

    return resampled_adata
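The heavy lifting in this helper is pandas' DataFrame.sample. As a quick, hypothetical illustration (toy labels, not part of the reference data), the same logic applied to a plain obs-style DataFrame behaves like this:

# Toy illustration only: 50 'A' cells and 5 'B' cells, with max_cells = 20 and min_cells = 10
rng = np.random.default_rng(3030)
toy_obs = pd.DataFrame({'label': ['A'] * 50 + ['B'] * 5})

# 'A' exceeds max_cells, so it is downsampled without replacement;
# 'B' is below min_cells, so it is upsampled with replacement.
big = toy_obs[toy_obs['label'] == 'A'].sample(20, replace = False, random_state = rng)
small = toy_obs[toy_obs['label'] == 'B'].sample(10, replace = True, random_state = rng)

pd.concat([big, small])['label'].value_counts()
# label
# A    20
# B    10
# Name: count, dtype: int64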
label_column = 'AIFI_L1'
max_cell_number = 100000
h5ad_uuid = '6e8972a5-9463-4230-84b4-a20de055b9c3'
adata = read_adata_uuid(h5ad_uuid)
downloading fileID: 6e8972a5-9463-4230-84b4-a20de055b9c3
Files have been successfully downloaded!
adata.shape
(1823666, 1261)
adata.obs[label_column].value_counts()
AIFI_L1
T cell             1152286
Monocyte            327919
B cell              160632
NK cell             147761
DC                   23287
Platelet              7903
Progenitor cell       1526
Erythrocyte           1508
ILC                    844
Name: count, dtype: int64
adata_subset = resample_anndata_min_max(
adata,
label_column,
max_cells = max_cell_number,
random_state = 3030
)
adata_subset.shape
(435068, 1261)
adata_subset.obs[label_column].value_counts()
AIFI_L1
B cell             100000
Monocyte           100000
NK cell            100000
T cell             100000
DC                  23287
Platelet             7903
Progenitor cell      1526
Erythrocyte          1508
ILC                   844
Name: count, dtype: int64
adata_subset = adata_subset.raw.to_adata()
adata_subset.shape
(435068, 33538)
sc.pp.normalize_total(adata_subset, target_sum=1e4)
sc.pp.log1p(adata_subset)
WARNING: adata.X seems to be already log-transformed.
keep_var = adata_subset.var.index.isin(flex_genes)
sum(keep_var)
18329
adata_subset = adata_subset[:,keep_var]
adata_subset.shape
(435068, 18329)
sc.pp.normalize_total(adata_subset, target_sum=1e4)
sc.pp.log1p(adata_subset)
WARNING: adata.X seems to be already log-transformed.
model_fs = multicelltypist.train(
adata_subset,
label_column,
n_jobs = 60,
max_iter = 10,
multi_class = 'ovr',
use_SGD = True
)
🍳 Preparing data before training
✂️ 1223 non-expressed genes are filtered out
🔬 Input data has 435068 cells and 17106 genes
⚖️ Scaling input data
🏋️ Training data using SGD logistic regression
⚠️ Warning: it may take a long time to train this dataset with 435068 cells and 17106 genes, try to downsample cells and/or restrict genes to a subset (e.g., hvgs)
✅ Model training done!
Detected genes:

multicelltypist drops non-expressed genes before training (1,223 were filtered out above), so we rebuild the list of genes that were actually detected (non-zero total counts) in the training data; the classifier coefficients are indexed against this filtered gene set.

# Identify genes with non-zero total counts in the training data
df = adata_subset.X.toarray()
flag = df.sum(axis = 0) == 0
gene = adata_subset.var_names[~flag]
Features with high absolute classifier coefficients for each cell type class

np.argpartition takes the coefficient scores for each class and moves the positions of the highest absolute coefficient scores to the end of an array of positions. We then select the last top_n positions from that array, which allows us to retrieve the genes with the highest absolute coefficients for each class. Combining these across classes gives a unique list of genes that are important for our model.
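As a small, self-contained sketch of that step (made-up coefficients for a single hypothetical class, not taken from the model), the returned positions point at the largest absolute values:

toy_coef = np.array([[0.1, -2.0, 0.3, 1.5, -0.2, 0.05]])  # hypothetical 1-class coefficient matrix
toy_top_n = 2
toy_index = np.argpartition(np.abs(toy_coef), -toy_top_n, axis = 1)[:, -toy_top_n:]
toy_index  # positions 1 and 3 (|-2.0| and |1.5|), in no guaranteed order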
top_n = 200
gene_index = np.argpartition(
np.abs(model_fs.classifier.coef_),
-top_n,
axis = 1
)
gene_index = gene_index[:, -top_n:]
gene_index = np.unique(gene_index)
print('Number of genes selected: {n}'.format(n = len(gene_index)))
Number of genes selected: 1102
selected_genes = gene[gene_index.tolist()]
selected_df = pd.DataFrame({'gene': selected_genes})
selected_df.head()
| | gene |
|---|---|
| 0 | HES4 |
| 1 | TTLL10 |
| 2 | TNFRSF18 |
| 3 | TNFRSF4 |
| 4 | TNFRSF25 |
adata = adata.raw.to_adata()
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
WARNING: adata.X seems to be already log-transformed.
adata = adata[:, adata.var_names.isin(selected_genes)]
adata.shape
(1823666, 1102)
model_fs = multicelltypist.train(
adata,
label_column,
n_jobs = 60,
max_iter = 100,
multi_class = 'ovr',
check_expression = False
)
🍳 Preparing data before training
🔬 Input data has 1823666 cells and 1102 genes
⚖️ Scaling input data
🏋️ Training data using logistic regression
✅ Model training done!
out_dir = 'output'
if not os.path.isdir(out_dir):
    os.makedirs(out_dir)
out_genes = 'output/ref_pbmc_clean_celltypist_top{n}_flex-features_{l}_{d}.csv'.format(
n = top_n,
l = label_column,
d = date.today()
)
selected_df.to_csv(out_genes)
out_model = 'output/ref_pbmc_clean_celltypist_model_flex-features_{l}_{d}.pkl'.format(
l = label_column,
d = date.today()
)
model_fs.write(out_model)
Finally, we'll use hisepy.upload.upload_files() to send a copy of our output to HISE to use for downstream analysis steps.
study_space_uuid = '64097865-486d-43b3-8f94-74994e0a72e0'
title = 'PBMC Reference {l} CellTypist Model Flex Features {d}'.format(
l = label_column,
d = date.today()
)
in_files = [h5ad_uuid]
in_files
['6e8972a5-9463-4230-84b4-a20de055b9c3']
out_files = [out_genes, out_model]
out_files
['output/ref_pbmc_clean_celltypist_top200_flex-features_AIFI_L1_2024-03-11.csv', 'output/ref_pbmc_clean_celltypist_model_flex-features_AIFI_L1_2024-03-11.pkl']
hisepy.upload.upload_files(
files = out_files,
study_space_id = study_space_uuid,
title = title,
input_file_ids = in_files
)
output/ref_pbmc_clean_celltypist_top200_flex-features_AIFI_L1_2024-03-11.csv
output/ref_pbmc_clean_celltypist_model_flex-features_AIFI_L1_2024-03-11.pkl
you are trying to upload file_ids... ['output/ref_pbmc_clean_celltypist_top200_flex-features_AIFI_L1_2024-03-11.csv', 'output/ref_pbmc_clean_celltypist_model_flex-features_AIFI_L1_2024-03-11.pkl']. Do you truly want to proceed?
{'trace_id': '3768e951-13f8-4aac-9c46-908d128c6e3c', 'files': ['output/ref_pbmc_clean_celltypist_top200_flex-features_AIFI_L1_2024-03-11.csv', 'output/ref_pbmc_clean_celltypist_model_flex-features_AIFI_L1_2024-03-11.pkl']}
import session_info
session_info.show()
-----
anndata             0.10.3
hisepy              0.3.0
multicelltypist     1.6.2
numpy               1.25.2
pandas              2.1.4
scanpy              1.9.6
session_info        1.0.0
-----
PIL 10.0.1 anyio NA arrow 1.3.0 asttokens NA attr 23.2.0 attrs 23.2.0 babel 2.14.0 beatrix_jupyterlab NA brotli NA cachetools 5.3.1 certifi 2024.02.02 cffi 1.16.0 charset_normalizer 3.3.2 cloudpickle 2.2.1 colorama 0.4.6 comm 0.1.4 cryptography 41.0.7 cycler 0.10.0 cython_runtime NA dateutil 2.8.2 db_dtypes 1.1.1 debugpy 1.8.0 decorator 5.1.1 defusedxml 0.7.1 deprecated 1.2.14 exceptiongroup 1.2.0 executing 2.0.1 fastjsonschema NA fqdn NA google NA greenlet 2.0.2 grpc 1.58.0 grpc_status NA h5py 3.10.0 idna 3.6 igraph 0.10.8 importlib_metadata NA ipykernel 6.28.0 ipython_genutils 0.2.0 ipywidgets 8.1.1 isoduration NA jedi 0.19.1 jinja2 3.1.2 joblib 1.3.2 json5 NA jsonpointer 2.4 jsonschema 4.20.0 jsonschema_specifications NA jupyter_events 0.9.0 jupyter_server 2.12.1 jupyterlab_server 2.25.2 jwt 2.8.0 kiwisolver 1.4.5 leidenalg 0.10.1 llvmlite 0.41.0 lz4 4.3.2 markupsafe 2.1.3 matplotlib 3.8.0 matplotlib_inline 0.1.6 mpl_toolkits NA mpmath 1.3.0 natsort 8.4.0 nbformat 5.9.2 numba 0.58.0 opentelemetry NA overrides NA packaging 23.2 parso 0.8.3 pexpect 4.8.0 pickleshare 0.7.5 pkg_resources NA platformdirs 4.1.0 plotly 5.18.0 prettytable 3.9.0 prometheus_client NA prompt_toolkit 3.0.42 proto NA psutil NA ptyprocess 0.7.0 pure_eval 0.2.2 pyarrow 13.0.0 pydev_ipython NA pydevconsole NA pydevd 2.9.5 pydevd_file_utils NA pydevd_plugins NA pydevd_tracing NA pygments 2.17.2 pynvml NA pyparsing 3.1.1 pyreadr 0.5.0 pythonjsonlogger NA pytz 2023.3.post1 referencing NA requests 2.31.0 rfc3339_validator 0.1.4 rfc3986_validator 0.1.1 rpds NA scipy 1.11.4 send2trash NA shapely 1.8.5.post1 six 1.16.0 sklearn 1.3.2 sniffio 1.3.0 socks 1.7.1 sql NA sqlalchemy 2.0.21 sqlparse 0.4.4 stack_data 0.6.2 sympy 1.12 termcolor NA texttable 1.7.0 threadpoolctl 3.2.0 torch 2.1.2+cu121 torchgen NA tornado 6.3.3 tqdm 4.66.1 traitlets 5.9.0 typing_extensions NA uri_template NA urllib3 1.26.18 wcwidth 0.2.12 webcolors 1.13 websocket 1.7.0 wrapt 1.15.0 xarray 2023.12.0 yaml 6.0.1 zipp NA zmq 25.1.2 zoneinfo NA zstandard 0.22.0
-----
IPython             8.19.0
jupyter_client      8.6.0
jupyter_core        5.6.1
jupyterlab          4.1.2
notebook            6.5.4
-----
Python 3.10.13 | packaged by conda-forge | (main, Dec 23 2023, 15:36:39) [GCC 12.3.0]
Linux-5.15.0-1053-gcp-x86_64-with-glibc2.31
-----
Session information updated at 2024-03-11 19:30