!pip install datasets umap-learn &> /dev/null # Import libraries from pathlib import Path import os import pickle import numpy as np import pandas as pd import re from functools import partial import matplotlib.pyplot as plt import seaborn as sns sns.set_context(context="paper", font_scale=1.5) # Wild type sequence of Rep78 from AAV2 rep78_seq = 'MPGFYEIVIKVPSDLDGHLPGISDSFVNWVAEKEWELPPDSDMDLNLIEQAPLTVAEKLQRDFLTEWRRVSKAPEALFFVQFEKGESYFHMHVLVETTGVKSMVLGRFLSQIREKLIQRIYRGIEPTLPNWFAVTKTRNGAGGGNKVVDECYIPNYLLPKTQPELQWAWTNMEQYLSACLNLTERKRLVAQHLTHVSQTQEQNKENQNPNSDAPVIRSKTSARYMELVGWLVDKGITSEKQWIQEDQASYISFNAASNSRSQIKAALDNAGKIMSLTKTAPDYLVGQQPVEDISSNRIYKILELNGYDPQYAASVFLGWATKKFGKRNTIWLFGPATTGKTNIAEAIAHTVPFYGCVNWTNENFPFNDCVDKMVIWWEEGKMTAKVVESAKAILGGSKVRVDQKCKSSAQIDPTPVIVTSNTNMCAVIDGNSTTFEHQQPLQDRMFKFELTRRLDHDFGKVTKQEVKDFFRWAKDHVVEVEHEFYVKKGGAKKRPAPSDADISEPKRVRESVAQPSTSDAEASINYADRYQNKCSRHVGMNLMLFPCRQCERMNQNSNICFTHGQKDCLECFPVSESQPVSVVKKAYQKLCYIHHIMGKVPDACTACDLVNVDLDDCIFEQ' len(rep78_seq) # change M225G. This will be considered wild type sequence from here on. Python string is zero-indexed wt_seq = rep78_seq[:224] + 'G' + rep78_seq[225:] len(wt_seq) # ensure length remains same after substitution # Link to csv file with fitness data from paper rep7868_selection_values_barcode = 'https://raw.githubusercontent.com/churchlab/aav_rep_scan/master/analysis/selection_values/rep7868_selection_values_barcode.csv' # Read the fitness data into a pandas dataframe # a, b = fitness values for transfection duplicates fitness_data = pd.read_csv(rep7868_selection_values_barcode) fitness_data # Remove rows containing stop codon (truncated proteins smaller than 621 AA) # And remove rows with data for 622 aminoacid position (rep78 is only 621 AA long) fitness_data_sub = fitness_data.loc[-((fitness_data['aa'] == '*')|(fitness_data['abs_pos'] == 622))].copy() fitness_data_sub # check if any fitness values are inf fitness_data_sub.loc[(fitness_data_sub['a'] == np.inf)|(fitness_data_sub['b'] == np.inf)].head() # number of rows where fitness values are inf len(fitness_data_sub.loc[(fitness_data_sub['a'] == np.inf)|(fitness_data_sub['b'] == np.inf)]) # check if any fitness values are NaN fitness_data_sub.loc[(fitness_data_sub['a'] == np.nan)|(fitness_data_sub['b'] == np.nan)] # Replacing infinite with nan fitness_data_sub.replace([np.inf, -np.inf], np.nan, inplace=True) # Dropping all the rows with nan values fitness_data_sub.dropna(inplace=True) fitness_data_sub # 40 rows containing infinite fitness values should be dropped # Group fitness values for a and b (transfection replicates) # by a combination of abs_pos (aminoacid position) and aa (aminoacid) fitness_data_grouped = fitness_data_sub.groupby(['abs_pos', 'aa'], as_index=False).agg({'a': lambda x: list(x), 'b': lambda x: list(x)}) fitness_data_grouped # New `fitness_values` column contains all the fitness values for a given aminoacid substitution at a position fitness_data_grouped['fitness_values'] = fitness_data_grouped["a"] + fitness_data_grouped["b"] fitness_data_grouped.head() len(fitness_data_grouped['fitness_values'][0]) == len(fitness_data_grouped["a"][0]) + len(fitness_data_grouped["b"][0]) def is_wt_aa(wt_seq, abs_pos, aa): """ If a given aminoacid substitution is wildtype return 1, else 0. Example M1M will return 1, M1A will return 0 """ return int(wt_seq[abs_pos-1] == aa) def aa_mutations(wt_seq, abs_pos, aa): """ Return mutation string. Example M1A. """ return wt_seq[abs_pos-1] + str(abs_pos) + aa # If a given aminoacid substitution is same as wildtype return 1, else 0. fitness_data_grouped['is_wt_aa'] = fitness_data_grouped.apply(lambda x: is_wt_aa(wt_seq, x['abs_pos'], x['aa']), axis=1) # Return mutation string, an example M1A. fitness_data_grouped['aa_mutations'] = fitness_data_grouped.apply(lambda x: aa_mutations(wt_seq, x['abs_pos'], x['aa']), axis=1) fitness_data_grouped # All the fitness values for wildtype sequences fitness_values_wt = np.concatenate(fitness_data_grouped.query('is_wt_aa == 1')['fitness_values'].values) # Fitness values for all the mutant sequences fitness_values_mutant = np.concatenate(fitness_data_grouped.query('is_wt_aa == 0')['fitness_values'].values) # Density plot of fitness values for wildtype vs mutants sns.histplot(np.log10(fitness_values_wt), label = "wildtype sequences", kde=True) sns.histplot(np.log10(fitness_values_mutant), label = "mutant sequences", kde=True) plt.xlabel('log10(fitness)') plt.ylabel('Density') plt.title("Distribution of fitness values: wild-type vs mutant sequences") plt.legend(); # Density plot of fitness values for wildtype vs mutants sns.histplot(fitness_values_wt, label = "wild-type sequences", kde=True) sns.histplot(fitness_values_mutant, label = "mutant sequences", kde=True) plt.xlabel('fitness') plt.ylabel('Density') plt.legend() plt.title("Distribution of fitness values: wild-type vs mutant sequences"); # Find the row with max wildtype fitness value fitness_data_sub.loc[(fitness_data_sub['a'] == max(fitness_values_wt))| (fitness_data_sub['b'] == max(fitness_values_wt))] # Find the row with max mutant fitness value fitness_data_sub.loc[(fitness_data_sub['a'] == max(fitness_values_mutant))| (fitness_data_sub['b'] == max(fitness_values_mutant))] # Put all the fitness values for wildtype sequence in one row and make a new dataframe # subset all mutant fitness values fitness_by_mutation = fitness_data_grouped.query('is_wt_aa == 0')[['aa_mutations', 'fitness_values']] fitness_by_mutation.loc[-1] = ["", fitness_values_wt] # insert wildtype fitness values in top row fitness_by_mutation = fitness_by_mutation.sort_index() # sorting by index fitness_by_mutation # Inspired from https://github.com/ElArkk/jax-unirep/blob/master/paper/gfp_prediction.ipynb def mut2seq(mutation_string, wt_sequence, delimiter=":"): """ Reconstruct full mutant sequence given mutation string. Example mutation_string: - P2C - P2T; G3A - A111T; Q194R; N249I; N251Y; H255Y Example wt_seq: 'MPGFYEIVIKVP' Example Output for mutation_string P2C: 'MPGFYEIVIKVP' -> 'MCGFYEIVIKVP' """ pattern = '([A-Z])(\d+)([A-Z])' if mutation_string is None or mutation_string == "": return wt_sequence mutations = mutation_string.split(delimiter) mutant_sequence = wt_sequence # mutant_sequence is a list for mut in mutations: match = re.search(pattern, mut) # Ensure mutation string matches regex pattern if not match: raise ValueError(f""" The mutation string {mut} is invalid.""") else: position = int(match.group(2)) # Return mutation position from mutation string. letter = match.group(3) # Return mutation letter from mutation string. if position == 0: raise ValueError( f""" The mutation string {mut} is invalid. It has "0" as its position. """ ) if position > len(wt_sequence): raise ValueError( f""" The mutation string {mut} is invalid. Its position is greater than the length of the WT sequence. """ ) mutant_sequence = mutant_sequence[:position-1] + letter + mutant_sequence[position:] # -1 is necessary because the python string is zero-indexed return mutant_sequence mut2rep78 = partial(mut2seq, wt_sequence=wt_seq) fitness_by_mutation['sequence'] = fitness_by_mutation['aa_mutations'].apply(mut2rep78) fitness_by_mutation['num_fitval'] = fitness_by_mutation['fitness_values'].apply(len) fitness_by_mutation['median_fitness'] = fitness_by_mutation['fitness_values'].apply(np.median) fitness_by_mutation['std'] = fitness_by_mutation['fitness_values'].apply(np.std) fitness_by_mutation["log10_fitness"] = fitness_by_mutation['median_fitness'].apply(np.log10) fitness_by_mutation # Ensure all the sequence are same length fitness_by_mutation['sequence'].apply(len).unique() # an example query fitness_by_mutation.query("aa_mutations.str.contains(r'[A-Z]2[A-Z]')").head(1) !mkdir -p data # Save the pandas dataframe using `pickle` system with open('data/fitness_by_mutation_rep7868aav2.pkl', 'wb') as f: pickle.dump(fitness_by_mutation, f) # Load the Rep78 DMS data for supervised training saved as pkl file # with open('data/fitness_by_mutation_rep7868aav2.pkl', 'rb') as f: fitness_by_mutation = pickle.load(f) fitness_by_mutation.head() sns.histplot(np.log10(fitness_by_mutation['median_fitness'].values), kde=True) plt.xlabel('log10(fitness)') plt.ylabel('Density') plt.title("Distribution of median fitness values"); sns.histplot(fitness_by_mutation['median_fitness'].values, kde=True) plt.xlabel('fitness') plt.ylabel('Density') plt.title("Distribution of median fitness values"); # median_fitness_wt value is used to normalize the entire dataset # such that wt fitness corresponds to a value of 1 median_fitness_wt = fitness_by_mutation.loc[-1, 'median_fitness'] sequences = fitness_by_mutation["sequence"].tolist() fitness = fitness_by_mutation.loc[:, 'median_fitness'].values # obtain target # transformation to normalize WT fitness value to 1 fitness_norm = (fitness - np.min(fitness))/(median_fitness_wt -np.min(fitness)) sns.histplot(fitness_norm, kde=True) plt.xlabel('fitness') plt.ylabel('Density') plt.title("Distribution of median fitness values: normalized"); esm1v_checkpoint = "facebook/esm1v_t33_650M_UR90S_1" # The AutoTokenizer class automatically retrieve the model's configuration, pretrained weights, # or vocabulary from the name of the checkpoint. from transformers import AutoTokenizer tokenizer = AutoTokenizer.from_pretrained(esm1v_checkpoint) # tokenize a single sequence tokenizer(sequences[0]) tokz = tokenizer(sequences) from datasets import Dataset ds = Dataset.from_dict(tokz) ds ds = ds.add_column("labels", fitness_norm) ds print(ds[0]); # Similar to the AutoTokenizer class, AutoModel has a from_pretrained() method # to load the weights of a pretrained model from transformers import AutoModel import torch esm1v_checkpoint = "facebook/esm1v_t33_650M_UR90S_1" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") model = AutoModel.from_pretrained(esm1v_checkpoint).to(device) # Retrieve the last hidden states for a single string text = sequences[0] inputs = tokenizer(text, return_tensors="pt") print(f"Input tensor shape: {inputs['input_ids'].size()}") inputs = {k:v.to(device) for k,v in inputs.items()} with torch.no_grad(): outputs = model(**inputs) print(outputs) outputs.last_hidden_state.size() outputs.last_hidden_state.mean(dim=1) torch.mean(outputs.last_hidden_state, dim=1).size() def extract_hidden_states(batch): # Place model inputs on the GPU inputs = {k:v.to(device) for k,v in batch.items() if k in tokenizer.model_input_names} # Extract last hidden states with torch.no_grad(): last_hidden_state = model(**inputs).last_hidden_state # Return vector for [CLS] token return {"hidden_state": torch.mean(last_hidden_state, dim=1).cpu().numpy()} ds.set_format("torch", columns=["input_ids", "attention_mask", "labels"]) torch.cuda.is_available() # use GPU for this step else very slow ds = ds.map(extract_hidden_states, batched=True, batch_size = 64) with open('data/rep7868aav2_emb_esm1v.pkl', 'wb') as f: pickle.dump(ds, f) ds X = np.array(ds["hidden_state"]) y = np.array(ds["labels"]) X.shape, len(y) from sklearn.model_selection import train_test_split # Split the dataset X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) X_train.shape, X_test.shape, len(y_train), len(y_test) sns.histplot(y_train, kde=True) plt.xlabel('fitness') plt.ylabel('Density') plt.title("Distribution of train fitness values"); sns.histplot(y_test, kde=True) plt.xlabel('fitness') plt.ylabel('Density') plt.title("Distribution of test fitness values"); from sklearn.decomposition import PCA num_pca_components = 60 pca = PCA(num_pca_components) X_train_pca = pca.fit_transform(X_train) fig_dims = (7, 6) fig, ax = plt.subplots(figsize=fig_dims) sc = ax.scatter(X_train_pca[:,0], X_train_pca[:,1], c=y_train, marker='.') ax.set_xlabel('PCA first principal component') ax.set_ylabel('PCA second principal component') plt.colorbar(sc, label='Variant Effect'); from umap import UMAP from sklearn.preprocessing import MinMaxScaler def get_umap_embedding(features, umap_params): # Initialize UMAP reducer = UMAP(random_state=7, **umap_params) # Fit UMAP mapper = reducer.fit(features) embedding = mapper.transform(features) return embedding umap_params = {"n_neighbors": 15, "min_dist": 0.1, "metric": "euclidean"} import matplotlib.colors as colors fig, ax = plt.subplots(nrows=1, ncols=1, figsize=(6, 6)) # Scale features to [0,1] range X_scaled = MinMaxScaler().fit_transform(X_train) embedding = get_umap_embedding(X_scaled, umap_params=umap_params) divnorm = colors.TwoSlopeNorm(vmin=y_train.min(), vcenter=1, vmax=y_train.max()) ax.scatter(embedding[:, 0], embedding[:, 1], c=y_train, cmap="coolwarm", norm=divnorm, s=10.0, alpha=1.0, marker="o", linewidth=0) ax.set(xlabel="UMAP Dim 0", ylabel="UMAP Dim 1") plt.tight_layout() plt.show(); from sklearn.linear_model import RidgeCV from sklearn.model_selection import KFold from sklearn.preprocessing import StandardScaler from sklearn.metrics import mean_squared_error, PredictionErrorDisplay, ndcg_score from sklearn.pipeline import make_pipeline %%capture # different alphas (regularization parameter) to try alphas = np.logspace(-6, 6, 100) kfold = KFold(n_splits=5, random_state=42, shuffle=True) # K-Folds cross-validator # define model ridgecv = make_pipeline(StandardScaler(), RidgeCV(alphas=alphas, scoring='neg_mean_squared_error', cv=kfold)) # fit model ridgecv.fit(X_train, y_train) ridge_best_alpha = ridgecv[1].alpha_ # make predictions preds_train = ridgecv.predict(X_train) preds_test = ridgecv.predict(X_test); # Regression from scipy import stats spearmanr = stats.spearmanr(a=preds_test, b=y_test, axis=0) print(spearmanr) print(ridge_best_alpha) from sklearn.linear_model import Ridge from sklearn.model_selection import cross_val_score kfold = KFold(n_splits=5, random_state=42, shuffle=True) # K-Folds cross-validator def RidgeSR(best_alpha, X_train, y_train, cv=kfold): ''' Perform a post hoc procedure to choose the strongest regularization such that the cross-validation performance was still statistically equal (by t-test) to the level of regularization we would select through normal cross-validation. This procedure selects a stronger regularization than what would be obtained using the ‘RidgeCV’ procedure implemented before. Input: `best_alpha`: best regularization parameter alpha from RidgeCV procedure ''' alphas = np.linspace(best_alpha, best_alpha*20, 20) ridge = make_pipeline(StandardScaler(), Ridge()) p_values = [] mse = [] for a in alphas: ridge[1].set_params(alpha = a) scores_posthoc = cross_val_score(ridge, X_train, y_train, scoring='neg_mean_squared_error', cv=kfold, n_jobs=-1) scores_posthoc = np.absolute(scores_posthoc) if a == best_alpha: scores = scores_posthoc mse.append(np.mean(scores_posthoc)) p_values.append(stats.ttest_ind(scores, scores_posthoc).pvalue) # choose alpha higher than best alpha # choose a stronger alpha than would be selected by RidgeCV such that cross validation performance # is still equal to best alpha from RidgeCV # When p value < 0.10 MSE distribution from stronger alpha is considered the same as from best alpha alpha_ridgsesr = alphas[np.where(1-np.asarray(p_values) <= 0.90)[0][-1]] mse_ridgesr = mse[np.where(1-np.asarray(p_values) <= 0.90)[0][-1]] # Plot alpha vs MSE and alpha vs p-value fig, (ax0, ax1) = plt.subplots(figsize=(12, 6), nrows=1, ncols=2) ax0.plot(alphas, mse, 'ro-') ax0.set(xlabel='alphas', ylabel='mean squared error') ax1.plot(alphas, p_values, 'bo-') ax1.set(xlabel='alphas', ylabel='p value') fig.suptitle(f'RidgeSR alpha: {alpha_ridgsesr}') plt.tight_layout() plt.show(); return alpha_ridgsesr # best alpha from RidgeCV procedure best_alpha = ridge_best_alpha # selects a stronger regularization than what would be obtained using the ‘RidgeCV’ procedure previously defined ridgsesr_alpha = RidgeSR(best_alpha, X_train, y_train, cv=kfold) # define model ridge = make_pipeline(StandardScaler(), Ridge(alpha=ridgsesr_alpha)) # fit model ridge.fit(X_train, y_train) # make predictions preds_train = ridge.predict(X_train) preds_test = ridge.predict(X_test) print(stats.spearmanr(a=preds_test, b=y_test, axis=0)) print(ridgsesr_alpha) from sklearn.ensemble import BaggingRegressor # define ridge model with RidgeSR alpha ridge = make_pipeline(StandardScaler(), Ridge(alpha=ridgsesr_alpha)) # define bag model bag = BaggingRegressor(estimator=ridge, n_estimators=100, max_samples=1.0, max_features=1.0, bootstrap=True, bootstrap_features=False, n_jobs=-1) # fit the data bag.fit(X_train, y_train) # make predictions preds_train = bag.predict(X_train) preds_test = bag.predict(X_test) spearmanr = stats.spearmanr(a=preds_test, b=y_test, axis=0) print(spearmanr) from sklearn.metrics import PredictionErrorDisplay fig, (ax0, ax1) = plt.subplots(figsize=(12, 6), nrows=1, ncols=2) # plot actual vs predicted values PredictionErrorDisplay.from_predictions( y_test, preds_test, ax=ax0, kind='actual_vs_predicted', scatter_kwargs={"alpha":0.5} ) ax0.plot([], [], " ", label=f"Spearman r: {np.round(spearmanr.statistic, 4)}") ax0.legend(loc="best") ax0.axis('tight') PredictionErrorDisplay.from_predictions( y_test, preds_test, kind='residual_vs_predicted', ax=ax1, scatter_kwargs={"alpha":0.5} ) ax1.plot([], [], " ", label=f"Spearman r: {np.round(spearmanr.statistic, 4)}") ax1.legend(loc="best") ax1.axis('tight') plt.tight_layout() plt.show(); # define ridge model with RidgeSR alpha ridge = make_pipeline(StandardScaler(), Ridge(alpha=ridgsesr_alpha)) # define bag model bag_emb_esm1v = BaggingRegressor(estimator=ridge, n_estimators=100, max_samples=1.0, max_features=1.0, bootstrap=True, bootstrap_features=False, oob_score=True, n_jobs=-1) # No train test split. Fit the bag model on entire dataset bag_emb_esm1v.fit(X, y) # make predictions preds_oob_esm1v = bag_emb_esm1v.oob_prediction_ ! mkdir -p models # save model with open('models/bag_emb_esm1v_rep7868aav2.pkl','wb') as f: pickle.dump(bag_emb_esm1v,f) print(stats.spearmanr(a=preds_oob_esm1v, b=y, axis=0)) ds_with_preds = ds.add_column("predicted_label", preds_oob_esm1v) ds_with_preds # create a DataFrame with the input_ids, losses, and predicted/true labels: ds_with_preds.set_format("pandas") cols = ['input_ids', 'labels', 'predicted_label'] df_with_preds = ds_with_preds[:][cols] df_with_preds.head() # Load the Rep78 DMS data for supervised training saved as pkl file # with open('data/fitness_by_mutation_rep7868aav2.pkl', 'rb') as f: fitness_by_mutation = pickle.load(f) sequences = fitness_by_mutation["sequence"].tolist() df_with_preds['sequence'] = sequences df_with_preds.head() from functools import partial def find_mutated_aa(best_seq, starter_seq): """Return aminoacid substitution between two protein sequences of the same length""" mutated_aa = [starter_seq[i]+str(i+1)+best_seq[i] for i in range(len(starter_seq)) if starter_seq[i] != best_seq[i]] return mutated_aa mut_in_rep78 = partial(find_mutated_aa, starter_seq=wt_seq) df_with_preds['mutated_aa'] = df_with_preds['sequence'].apply(mut_in_rep78) df_with_preds df_with_preds.sort_values(by = ['predicted_label', 'labels'], ascending = [False, False]).head(100) df_with_preds.sort_values(by = ['predicted_label', 'labels'], ascending = [False, False]).to_csv('data/rep7868aav2_preds_emb_esm1v.csv') # Download model from google.colab import files files.download("models/bag_emb_esm1v_rep7868aav2.pkl") !md5sum models/bag_emb_esm1v_rep7868aav2.pkl # download all files from colab data folder !zip -r /content/data.zip /content/data files.download('/content/data.zip')