#!/usr/bin/env python
# coding: utf-8

# In[1]:

import time

import numpy as np
import pandas as pd
import scanpy as sc
import anndata as ad
import sklearn as sk
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report

# `display` is injected by IPython/Jupyter. Fall back to `print` so this
# exported script also runs under a plain Python interpreter.
try:
    display
except NameError:
    display = print

# Load every available dataset.

# In[2]:

# Each row of personal.csv names one dataset through its Author and Year
# columns. Datasets whose .h5ad file is not present locally are skipped on
# purpose, so only the files that exist end up in `adatas`.
load_start = time.perf_counter()
adatas = []
for _, row in pd.read_csv('personal.csv').iterrows():
    dataset_name = f'{row.Author}_{row.Year}'
    try:
        adata = sc.read(f'datasets/{dataset_name}.h5ad')
    except FileNotFoundError:
        continue
    # Tag every cell with its dataset of origin so it survives concatenation.
    adata.obs['dataset'] = dataset_name
    adatas.append(adata)

if not adatas:
    # Fail with an explicit message instead of an obscure error from concat.
    raise FileNotFoundError(
        'none of the datasets listed in personal.csv were found under datasets/'
    )

adata = ad.concat(adatas)
print(f'loaded {len(adatas)} datasets in {time.perf_counter() - load_start:.1f} s')

# ## Explore the data

# In[3]:

import seaborn as sns
import matplotlib.pyplot as plt

sc.set_figure_params(dpi=100, frameon=False)

# In[38]:

# For each dataset, show how many samples every condition has.
# value_counts() orders conditions from most to least frequent, so the
# positional index added by the second reset_index() is the condition's rank.
for a in adatas:
    df = pd.DataFrame(a.obs.perturbation_name.value_counts())
    df = df.reset_index().reset_index()
    df.columns = ['nth condition', 'condition', 'n_samples']
    display(df.head(5))
    plt.figure(figsize=(10, 4))
    sns.scatterplot(data=df, x='nth condition', y='n_samples', size=1, legend=None)
    plt.yscale('log')
    plt.title(a.obs.dataset.values[0])
    plt.show()

# Which perturbations are in these datasets?

# In[26]:

# Rows: perturbation names, columns: datasets, values: number of samples.
df = pd.crosstab(adata.obs.perturbation_name, adata.obs.dataset)
df['sum'] = df.sum(axis=1)

# In[38]:

pd.crosstab(adata.obs.perturbation_name, adata.obs.dataset)

# In[34]:

# Sort once by total sample count; reused for the table and the dot plot.
ranked = df.sort_values(by='sum', ascending=False)
ranked[:50]

# In[27]:

# Wrap the count table (without the helper 'sum' column) in an AnnData object
# so that scanpy's plotting functions can be used on it.
freq = ad.AnnData(ranked.drop('sum', axis=1))

# In[28]:

# Move the perturbation names from the index into a regular obs column.
freq.obs = freq.obs.reset_index()

# In[29]:

freq.obs.perturbation_name.values[:20]

# In[39]:

# Dot plot of sample counts for the first 50 perturbations in `freq`.
sc.pl.dotplot(
    freq[freq.obs.perturbation_name.isin(freq.obs.perturbation_name.values[:50])],
    var_names=freq.var_names,
    groupby='perturbation_name',
    colorbar_title='n_samples',
)

# In[45]:

# Scratch cell: quick check of list.index, unrelated to the analysis.
['a', 'b'].index('b')

# Which perturbations exist across multiple datasets?
# In[41]:

# Keep only the perturbations with a non-zero sample count in more than one
# dataset.
df = pd.crosstab(adata.obs.perturbation_name, adata.obs.dataset)
df = df[np.count_nonzero(df, axis=1) > 1]
df

# In[5]:

# Sample counts of the shared perturbations in the first three dataset columns.
sns.barplot(
    data=df[df.columns[:3]].reset_index().melt(id_vars='perturbation_name'),
    x='perturbation_name',
    y='value',
    hue='dataset',
)
plt.yscale('log')
plt.legend(bbox_to_anchor=(1.6, 1.01));

# In[8]:

import pubchempy as pcp

# In[28]:

import time

# Look up every Srivatsan_2019 perturbation name on PubChem and report names
# that are ambiguous (several CIDs), unknown (no CID), or that differ from the
# first synonym PubChem lists for the compound.
lookup_start = time.perf_counter()
cids = []  # defined up front so the inspection cell below never hits a NameError
for p in adata[adata.obs.dataset == 'Srivatsan_2019'].obs.perturbation_name.unique():
    print(p)
    try:
        cids = pcp.get_cids(p, 'name', list_return='flat')
        if not cids:
            print(f'oh no, nothing found for {p}')
        elif len(cids) > 1:
            print(cids)
        else:
            # Exactly one hit: fetch the compound by its scalar CID.
            cmpd = pcp.Compound.from_cid(cids[0])
            synonyms = cmpd.synonyms
            if not synonyms or p.lower() != synonyms[0].lower():
                print(synonyms[:10])
    except Exception as err:
        # Best-effort lookup: report the failure and move on to the next name
        # rather than silently discarding it. KeyboardInterrupt still stops
        # the loop because it is not an Exception subclass.
        print(f'PubChem lookup failed for {p}: {err!r}')
        continue
print(f'PubChem lookups took {time.perf_counter() - lookup_start:.1f} s')

# In[23]:

# CIDs returned by the most recent successful get_cids call.
cids

# In[14]:

cid = pcp.get_cids('panobinostat', 'name', list_return='flat')[0]
cmpd = pcp.Compound.from_cid(cid)

# In[16]:

cmpd.synonyms

# ## Train a model
# Let's train the simplest model we can - a sgRNA classifier. From the weights of the model, you could extract common features of the sgRNA's effect.

# In[43]:

# subset data
subset = adata[adata.obs.perturbation_name.isin(['EGR1', 'IRF1', 'NCL', 'SET'])]

# In[6]:

# prepare train/test split
from sklearn.model_selection import train_test_split

# train_test_split returns (train, test) in that order; with test_size=.2 the
# model is fitted on 80% of the cells and evaluated on the remaining 20%.
train_idx, test_idx = train_test_split(subset.obs.index, test_size=.2)
test = subset[test_idx]
train = subset[train_idx]

# In[7]:

# train classifier
clf = sk.linear_model.LogisticRegression()
clf.fit(train.X, train.obs.perturbation_name.values)

# In[10]:

# evaluate prediction
print(sk.metrics.classification_report(
    test.obs.perturbation_name.values,
    clf.predict(test.X)
))

# Not bad for a 4-class classification problem.
#
# Much harder case: let's try holding out Norman et al. 2019 and learning the treatments only from the other two datasets. We've performed no batch correction between datasets, so we would expect this to be extremely difficult for a linear model.
# In[11]:

# Hold out every Norman_2019 cell for evaluation and fit on all other cells
# of the subset.
held_out = subset.obs.dataset == 'Norman_2019'
test = subset[held_out]
train = subset[~held_out]

# In[12]:

# Same model as before: logistic regression with scikit-learn's defaults,
# predicting the perturbation name from the expression matrix.
clf = LogisticRegression()
clf.fit(train.X, train.obs.perturbation_name.values)

# In[13]:

# Per-class precision/recall/F1 on the held-out dataset.
predicted = clf.predict(test.X)
print(classification_report(test.obs.perturbation_name.values, predicted))

# Indeed, we see that the linear model is not at all predictive due to the batch effect across datasets.