#!/usr/bin/env python
# coding: utf-8

# In[ ]:


import os
import sys
import copy
import requests
import numpy as np
import pandas as pd
from scipy import stats
from pathlib import Path
from tqdm.notebook import trange
from tqdm.notebook import tqdm
from urllib.request import urlretrieve


# In[ ]:


# git for functions loading and work path finding:
# locate the repository root and make it importable so `function.*` resolves
import git
repo = git.Repo('.', search_parent_directories=True)
work_path = Path(repo.working_tree_dir)
if str(work_path) not in sys.path:
    sys.path.append(str(work_path))


# In[ ]:


from function.seqfilter import SeqFilter
from function.utilities import seq_aa_check
from function.cutpondr import CutPONDR
from function.ebi import EbiAPI

# plot
import plotly
import plotly.graph_objects as go
from plotly.graph_objs import Layout

# attention map decorator
# https://github.com/luo3300612/Visualizer
# activate() is called here, before the model notebook is %run below,
# so the decorated attention function records its locals.
from visualizer import get_local
get_local.activate()


# # 1. Load pretrain model

# In[ ]:


# download pretrained weight from OSF: https://osf.io/jk29b/
pretrain_model_path = work_path / 'trained_weight.pt'
if not pretrain_model_path.is_file():
    url = 'https://osf.io/y2jh8/download'
    # Download into a side file and rename only once the transfer finished.
    # Writing straight to trained_weight.pt would leave a truncated file after
    # an interrupted download; the is_file() check above would then skip the
    # download on every later run and torch.load would likely fail.
    partial_path = pretrain_model_path.with_name(pretrain_model_path.name + '.part')
    try:
        urlretrieve(url, str(partial_path))
    except BaseException:
        partial_path.unlink(missing_ok=True)
        raise
    partial_path.replace(pretrain_model_path)


# In[ ]:


# load model architecture from 2_model_training.ipynb by IPython's magic command
pretrain_ipynb = str(work_path / '2_model_training.ipynb')


# In[ ]:


# `$pretrain_ipynb` is expanded by IPython; the executed notebook defines
# moco_builder, AttenTorchScratch, embed_dim, moco_mlp_dim, nce_temp, torch, ...
get_ipython().run_line_magic('run', '$pretrain_ipynb')


# In[ ]:


# rebuild MoCo, load the trained weights on CPU and keep an eval-mode copy of
# its base encoder for attention extraction
moco = moco_builder.MoCo(base_encoder=AttenTorchScratch, dim=embed_dim, mlp_dim=moco_mlp_dim, T=nce_temp)
moco.load_state_dict(torch.load(pretrain_model_path, map_location=torch.device('cpu')))
base_encoder = copy.deepcopy(moco.base_encoder)
base_encoder = base_encoder.eval()
# # 2. Inference

# ## 2.1 Function for attention map visualization

# In[ ]:


# initialize PONDR web crawler and sequence length function
pondr_driver = CutPONDR()
seqfilter = SeqFilter()


def minmaxnormalize(data):
    """Min-max scale attention scores into the range [0, 1].

    Args:
        data: array-like of numbers (converted with ``np.asarray``).

    Returns:
        A floating-point array with the same shape as ``data``. When every
        value is identical the range is zero; an all-zero array is returned
        instead of the all-NaN result ``(x - min) / (max - min)`` would give.

    Raises:
        ValueError: if ``data`` is empty (raised by ``np.min``).
    """
    data = np.asarray(data)
    data_min = np.min(data)
    data_range = np.max(data) - data_min
    if data_range == 0:
        return np.zeros(data.shape)
    return (data - data_min) / data_range


class NoDisorderRegion(Exception):
    """Raised when no disorder region survives the PONDR length filter."""


class SeqEncodeForPlot():
    """Encode one protein fragment and collect its last-layer CLS attention.

    On construction the fragment is resolved (from UniProt + PONDR, or taken
    verbatim), encoded with ``seqprocess.seq_process_pipe``, passed through the
    global ``base_encoder`` and post-processed into one row per residue
    feature group, ready for ``plot``.

    Attributes set by ``__init__``: title, frag_seq, seq_length, start_mod,
    end_mod, plot_label, encode_seq, atten, last_cls_atten_mean,
    last_cls_atten_mean_cut, last_cls_atten_mean_cut_normalize,
    sep_cls_for_plot, sep_feature_label.
    """

    # Longest fragment (in residues) taken from a UniProt disorder region;
    # the same limit is announced to the user in the prompt text below.
    _MAX_FRAG_LENGTH = 512

    def __init__(self, input_type, sequence):
        """
        Args:
            input_type: 'uniprot' to treat ``sequence`` as a UniProt entry ID
                (e.g. 'Q13148') and pick a PONDR disorder region, or 'custom'
                to use ``sequence`` as the amino-acid string (e.g. 'AAAAAAAA').
            sequence: UniProt ID or raw sequence, depending on ``input_type``.

        Raises:
            ValueError: if ``input_type`` is neither 'uniprot' nor 'custom'.
            NoDisorderRegion: if a UniProt protein has no usable disorder region.
        """
        # get seq frag
        if input_type == 'uniprot':
            frag_out = self.__get_vis_seq_uid(sequence)
        elif input_type == 'custom':
            frag_out = self.__get_vis_seq_custom(sequence)
        else:
            # without this, frag_out is unbound and the next line fails obscurely
            raise ValueError(
                "input_type must be 'uniprot' or 'custom', got {!r}".format(input_type))
        self.title = frag_out['title']
        self.frag_seq = frag_out['frag_seq']
        self.seq_length = len(frag_out['frag_seq'])
        self.start_mod = frag_out['start_mod']
        self.end_mod = frag_out['end_mod']

        # get plot label
        self.plot_label = self.__get_plot_label(self.frag_seq, self.start_mod, self.end_mod)

        # encode
        self.encode_seq = seqprocess.seq_process_pipe([self.frag_seq])

        # get atten
        self.atten = self.__get_atten_from_decorator(self.encode_seq)

        # post process 1: get last cls atten mean
        self.last_cls_atten_mean = self.__get_last_cls_atten_mean(self.atten)

        # post process 2: cut cls
        self.last_cls_atten_mean_cut = self.__cut_cls(self.last_cls_atten_mean, self.seq_length)

        # post process 3: normalize
        self.last_cls_atten_mean_cut_normalize = self.__normalize_cls(self.last_cls_atten_mean_cut)

        # post process 4: feature sep
        sep_out = self.__get_feature_sep(self.frag_seq, self.last_cls_atten_mean_cut_normalize)
        self.sep_cls_for_plot = sep_out['sep_cls_for_plot']
        self.sep_feature_label = sep_out['sep_feature_label']

    ###########get seq###########
    def __get_vis_seq_custom(self, custom_sequence):
        """Use the whole custom sequence; positions are reported 1-based."""
        start_mod = 1
        end_mod = len(custom_sequence)
        return {
            "title": '{}, {}~{}'.format('custom', start_mod, end_mod),
            "frag_seq": custom_sequence,
            "start_mod": start_mod,
            "end_mod": end_mod
        }

    def __get_vis_seq_uid(self, uniprot_id):
        """Fetch a UniProt protein and interactively choose one disorder region.

        Disorder regions come from PONDR (VSL2) after
        ``seqfilter.length_filter_by_od_ident`` (disorder >= 40, order >= 10).
        With several regions the user picks one by index; a region longer than
        ``_MAX_FRAG_LENGTH`` makes the user type start/end slice indices.
        Invalid answers are re-asked instead of crashing later.

        Returns:
            dict with 'title', 'frag_seq', 'start_mod' (start + 1) and 'end_mod'.

        Raises:
            NoDisorderRegion: if no region passes the length filter.
        """
        # retriving sequence
        seq_info = self.__get_sequence_online(uniprot_id, "VSL2", pondr_driver)
        gene_name = seq_info['gene_name']
        od_ident = seq_info['od_ident']
        protein_sequence = seq_info['protein_sequence']

        # od_ident length filter
        od_ident = seqfilter.length_filter_by_od_ident(od_ident, disorder_filter_length=40, order_filter_length=10)
        od_index = seqfilter.get_od_index(od_ident)['disorder_region']

        if len(od_index) == 0:
            raise NoDisorderRegion("This protein does not have disorder region")
        elif len(od_index) == 1:
            print("Only 1 disorder region {}, automatically use that".format(od_index[0]))
            od_index_i = 0
        else:
            choose_hint = ''.join(
                "region {}: {}".format(index, element) + "\n" + ' '
                for index, element in enumerate(od_index))
            while True:
                od_index_i = int(
                    input("Please choose disorder region: \n {}".format(choose_hint)))
                # validate here: an out-of-range index used to raise IndexError
                # only after the (slow) PONDR request had already been made
                if 0 <= od_index_i < len(od_index):
                    break
                print("please enter a region number between 0 and {}".format(len(od_index) - 1))

        # get protein seq by chosen index
        start = od_index[od_index_i]['start']
        end = od_index[od_index_i]['end']
        if (end - start) > self._MAX_FRAG_LENGTH:
            print("sequence length longer than 512: please specify the start and end")
            while True:
                start = int(input("start: \n"))
                end = int(input("end: \n"))
                # the typed range must be a non-empty slice of the protein that
                # respects the length limit, otherwise ask again
                if (0 <= start < end <= len(protein_sequence)
                        and end - start <= self._MAX_FRAG_LENGTH):
                    break
                print("invalid range: need 0 <= start < end <= {} and end - start <= {}".format(
                    len(protein_sequence), self._MAX_FRAG_LENGTH))
        frag_seq = protein_sequence[start:end]

        # mod start index
        start_mod = start + 1
        end_mod = end
        return {
            "title": '{}, {}, {}~{}'.format(uniprot_id, gene_name, start_mod, end_mod),
            "frag_seq": frag_seq,
            "start_mod": start_mod,
            "end_mod": end_mod
        }

    def __get_sequence_online(self, uniprot_id, od_ident_algorithm, pondr_driver):
        """Fetch the sequence via ``EbiAPI`` and its order/disorder string via PONDR.

        Returns:
            dict with 'uniprot_id', 'gene_name', 'protein_sequence' (after
            ``seq_aa_check``) and 'od_ident' (from ``pondr_driver.get_od_ident``).
        """
        # get seqeucne by uniprot id
        a_protein = EbiAPI(uniprot_id)

        # get sequence
        protein_sequence = a_protein.protein_sequence
        protein_sequence = seq_aa_check(protein_sequence)
        gene_name = a_protein.gene_name
        print('protein: {}, gene name: {}, length: {}'.format(uniprot_id, gene_name, len(protein_sequence)))

        # use pondr to get od_ident
        print("sending to PONDR by algorithm {}...".format(od_ident_algorithm))
        pondr_driver.cut(protein_sequence, protein_name='aa', algorithm=od_ident_algorithm)
        od_ident = pondr_driver.get_od_ident()

        return {
            "uniprot_id": uniprot_id,
            "gene_name": gene_name,
            "protein_sequence": protein_sequence,
            "od_ident": od_ident
        }
    ###########get seq###########

    # lable plot
    def __get_plot_label(self, frag_seq, start_mod, end_mod):
        """Build x-axis labels '<1-based position>_<residue>' for each residue."""
        return ["{}_{}".format(index + start_mod, element)
                for index, element in enumerate(frag_seq)]

    # get atten
    def __get_atten_from_decorator(self, encoded_seq):
        """Run ``base_encoder`` once and return the attention captured by ``get_local``.

        Returns:
            torch tensor reshaped/moved to (depth, num_heads, length, length).
        """
        get_local.clear()
        # only the captured attention values are needed, so skip autograd
        with torch.no_grad():
            _ = base_encoder(encoded_seq, return_all_atten=True)[1]

        # tidy attention_map from decorator to (layer, head, length, length)
        atten = get_local.cache
        atten = np.stack(atten['scaled_dot_product_attention']).squeeze()
        # NOTE(review): this reshape assumes the stacked cache is ordered
        # head-major. If scaled_dot_product_attention is called once per layer
        # with a (batch, head, L, L) tensor, the stack is (layer, head, L, L)
        # and reshape() would mix layers; a plain transpose would be needed.
        # Confirm against the model defined in 2_model_training.ipynb.
        atten = atten.reshape([num_heads, depth, atten.shape[-2], atten.shape[-2]])  # head, layer, length, length
        atten = np.moveaxis(atten, [0, 1], [1, 0])  # layer, head, length, length
        return torch.tensor(atten)

    ###########post precess###########
    # get last cls atten
    def __get_last_cls_atten_mean(self, atten):
        """Mean over heads of the last layer's CLS row, excluding the CLS column."""
        last_cls_atten = atten[-1, :, 0, 1:]
        last_cls_atten_mean = last_cls_atten.mean(dim=0, keepdim=False)
        last_cls_atten_mean = last_cls_atten_mean.detach().numpy()
        return last_cls_atten_mean

    # cut cls
    def __cut_cls(self, last_cls_atten_mean, seq_length):
        """Keep only the first ``seq_length`` positions (drops trailing padding)."""
        last_cls_atten_mean = last_cls_atten_mean[:seq_length]
        return last_cls_atten_mean

    # normalize
    def __normalize_cls(self, last_cls_atten_mean):
        """Min-max normalise the attention vector to [0, 1]."""
        last_cls_atten_mean = minmaxnormalize(last_cls_atten_mean)  # change normalize way
        return last_cls_atten_mean

    # feature sep
    def __get_feature_sep(self, frag_seq, last_cls_atten_mean):
        """Split attention into one row per residue group, NaN outside the group.

        Returns:
            dict with 'sep_cls_for_plot' (array of shape (9, len(frag_seq)): 8
            group rows followed by the unmasked attention) and
            'sep_feature_label' (the matching 9 row labels).
        """
        # (row label, one-letter residues belonging to the group), in plot order
        feature_groups = [
            ("Hydrophobic (A I L M P V)", 'LPAVMI'),
            ('C', 'C'),
            ('G', 'G'),
            ('S T', 'ST'),
            ('Prion like (Q N)', 'QN'),
            ('Negative charge (D E)', 'DE'),
            ('Positive charge (R K H)', 'RKH'),
            ('Aromatic (W Y F)', 'WYF'),
        ]
        residues = np.array(list(frag_seq))

        sep_feature_label = [label for label, _ in feature_groups] + ['All features']
        rows = [np.where(np.isin(residues, list(members)), last_cls_atten_mean, np.nan)
                for _, members in feature_groups]
        rows.append(last_cls_atten_mean)

        return {
            "sep_cls_for_plot": np.stack(rows),
            "sep_feature_label": sep_feature_label
        }
    ###########post precess###########


# ## 2.2 Get attention map

# In[ ]:


# usage case1: directly input custom sequence
# seqinfo = SeqEncodeForPlot("custom","REPNQAFGSGNNSYSGSNSGAAIGWGSASNAGSGSGFNGGFGSSMDSKSSGWGM")

# usage case2: by entering Uniprot Entry ID,
# sending to PONDR by VSL2 to get disorder regions which fit the length criteria (>=40)
# seqinfo = SeqEncodeForPlot("uniprot", "Q13148")


# # 3. Plot feature maps

# In[ ]:


def plot(seqinfo, transparent_bg=True):
    """Render ``seqinfo.sep_cls_for_plot`` as a residue-by-feature heatmap.

    Args:
        seqinfo: a ``SeqEncodeForPlot`` instance.
        transparent_bg: True (default) uses a transparent background (for
            HTML); False uses opaque white, suitable for static image export.

    Returns:
        plotly.graph_objects.Figure
    """
    custom_color_scale = ['rgb(220,220,220)', 'rgb(255,243,59)', 'rgb(253,199,12)',
                          'rgb(243,144,63)', 'rgb(237,104,60)', 'rgb(233,62,58)']
    bg_color = 'rgba(0,0,0,0)' if transparent_bg else 'rgba(255,255,255,1)'
    n_cols = len(seqinfo.plot_label)
    n_rows = len(seqinfo.sep_feature_label)

    fig = go.Figure(
        data=[go.Heatmap(
            # y carries the feature-group labels; plotly breaks hover lines with <br>
            hovertemplate='feature: %{y}<br>aa: %{x}<br>value: %{z}',
            z=seqinfo.sep_cls_for_plot,
            x=seqinfo.plot_label,
            ygap=1,
            y=seqinfo.sep_feature_label,
            colorscale=custom_color_scale,
            zmin=0, zmax=1,
        )],
        layout=Layout(paper_bgcolor=bg_color, plot_bgcolor=bg_color),
    )

    # layout
    fig.update_layout(
        title=seqinfo.title,
        width=seqinfo.seq_length * 10,
        height=450,
        yaxis_showticklabels=True,
        yaxis=dict(tickfont=dict(size=12, color='black')),
        xaxis_showticklabels=True,
        xaxis_tickmode='linear',
        font=dict(size=8))

    # hover method
    fig.update_layout(hovermode='x unified')

    # no scale bar
    fig.update_traces(showscale=False)

    # dashed separators between feature rows; categorical cells span
    # [-0.5, n_cols - 0.5], so the lines stop at the last cell's edge
    fig.update_layout(shapes=[
        dict(type='line', yref='y', y0=y, y1=y, xref='x', x0=-0.5, x1=n_cols - 0.5,
             line=dict(color='rgb(200, 200, 200)', width=0.5, dash="dash"))
        for y in [row + 0.5 for row in range(n_rows - 1)]
    ])
    return fig


# In[ ]:


# plot(seqinfo)


# In[ ]: