#!/usr/bin/env python
# coding: utf-8
# In[ ]:
import os
import sys
import copy
import requests
import numpy as np
import pandas as pd
from scipy import stats
from pathlib import Path
from tqdm.notebook import trange
from tqdm.notebook import tqdm
from urllib.request import urlretrieve
# In[ ]:
# use git to locate the repository root (work_path) and add it to sys.path so the local `function` package can be imported
import git
# Locate the repository root by searching upward from the current directory.
repo = git.Repo('.', search_parent_directories=True)
work_path = Path(repo.working_tree_dir)
# Add the repo root to the import path (only once, even if this cell is re-run)
# so the project-local `function` package imported below can be found.
if str(work_path) not in sys.path:
    sys.path.append(str(work_path))
# In[ ]:
from function.seqfilter import SeqFilter
from function.utilities import seq_aa_check
from function.cutpondr import CutPONDR
from function.ebi import EbiAPI
# plot
import plotly
import plotly.graph_objects as go
from plotly.graph_objs import Layout
# attention map decorator
# https://github.com/luo3300612/Visualizer
from visualizer import get_local
# Enable get_local's caching of decorated functions' local variables.
# NOTE(review): presumably this must run before the decorated model code
# (loaded later via %run) is defined — confirm against the Visualizer README.
get_local.activate()
# # 1. Load pretrained model
# In[ ]:
# Download the pretrained weights from OSF (https://osf.io/jk29b/) once and
# cache them at the repository root.
pretrain_model_path = work_path / 'trained_weight.pt'
if not pretrain_model_path.is_file():
    url = 'https://osf.io/y2jh8/download'
    # Fetch into a temporary sibling file and rename it only after the download
    # completes. Writing straight to trained_weight.pt would let an interrupted
    # transfer leave a truncated file that the is_file() check above accepts
    # on every later run.
    partial_path = pretrain_model_path.with_name(pretrain_model_path.name + '.part')
    try:
        urlretrieve(url, str(partial_path))
        partial_path.replace(pretrain_model_path)
    finally:
        # Remove leftovers of a failed download (no-op after a successful rename).
        if partial_path.exists():
            partial_path.unlink()
# In[ ]:
# load model architecture from 2_model_training.ipynb by IPython's %run magic:
# the notebook is executed in this namespace, so its definitions become
# globals here (the cells below use e.g. moco_builder, AttenTorchScratch,
# seqprocess, num_heads, depth, torch without defining them).
pretrain_ipynb = str(work_path / '2_model_training.ipynb')
# In[ ]:
# `$pretrain_ipynb` is IPython variable expansion of the path above.
# NOTE(review): get_ipython() exists only inside IPython/Jupyter; running this
# exported .py with plain `python` fails here. A path containing spaces may
# also need quoting — confirm.
get_ipython().run_line_magic('run', '$pretrain_ipynb')
# In[ ]:
# Rebuild the MoCo wrapper. moco_builder, AttenTorchScratch, embed_dim,
# moco_mlp_dim, nce_temp and torch are not defined in this file; presumably
# they come from the %run of 2_model_training.ipynb above.
moco = moco_builder.MoCo(base_encoder=AttenTorchScratch, dim=embed_dim, mlp_dim=moco_mlp_dim, T=nce_temp)
# Map tensors to CPU so loading does not require a GPU.
# NOTE(review): torch.load unpickles the file, which was downloaded from the
# network; consider weights_only=True if the installed torch supports it.
moco.load_state_dict(torch.load(pretrain_model_path, map_location=torch.device('cpu')))
# Only the base encoder is used for inference; keep an independent copy of it.
base_encoder = copy.deepcopy(moco.base_encoder)
# Switch to evaluation mode (eval() returns the module itself).
base_encoder = base_encoder.eval()
# # 2. Inference
# ## 2.1 Function for attention map visualization
# In[ ]:
# initialize module-level helpers used by SeqEncodeForPlot: the PONDR web
# crawler (disorder prediction) and the sequence filter that drops
# order/disorder regions shorter than a given length
pondr_driver = CutPONDR()
seqfilter = SeqFilter()
# attention score normalize function
def minmaxnormalize(data):
    '''Min-max scale ``data`` to the range [0, 1].

    Parameters
    ----------
    data : array-like of numbers
        Values to rescale (a numpy array, or anything np.asarray accepts).

    Returns
    -------
    numpy.ndarray
        ``(data - min) / (max - min)``. When all values are equal (including a
        single element) the range is zero, so an all-zero array is returned
        instead of NaNs. An empty input yields an empty float array.
    '''
    data = np.asarray(data)
    if data.size == 0:
        # np.min/np.max raise on empty arrays
        return data.astype(float)
    low = np.min(data)
    span = np.max(data) - low
    if span == 0:
        # avoid 0/0 -> NaN (and a RuntimeWarning) for constant input
        return np.zeros(data.shape)
    return (data - low) / span
# raised by SeqEncodeForPlot when no usable disorder region is found
class NoDisorderRegion(Exception):
    '''Signal that a protein has no disorder region left after length filtering.'''
# attention map visualization
class SeqEncodeForPlot:
    '''Select a protein fragment, encode it and prepare its attention for plotting.

    The fragment comes from one of two sources:

    - 'custom': the sequence passed in.
    - 'uniprot': a disorder region of a UniProt entry. The full sequence is
      fetched with EbiAPI and disorder is predicted by PONDR (VSL2). Regions
      are length-filtered and one is chosen (interactively when there are
      several).

    The fragment is encoded with ``seqprocess.seq_process_pipe`` and run
    through ``base_encoder``. The last layer's attention from position 0
    (the CLS token, per the original comments) to the residues is then
    processed in four steps:

    1. Average over heads.
    2. Trim to the fragment length.
    3. Min-max normalise.
    4. Split into one row per amino-acid feature group.

    Parameters
    ----------
    input_type : {'uniprot', 'custom'}
        Where the fragment comes from.
    sequence : str
        UniProt accession (e.g. 'Q13148') for 'uniprot', or the raw amino-acid
        sequence (e.g. 'AAAAAAAA') for 'custom'.

    Raises
    ------
    ValueError
        If ``input_type`` is not 'uniprot' or 'custom'.
    NoDisorderRegion
        If a UniProt entry has no disorder region left after length filtering.
    '''

    def __init__(self, input_type, sequence):
        '''Run the whole pipeline; see the class docstring for the arguments.'''
        # get seq frag
        if input_type == 'uniprot':
            frag_out = self.__get_vis_seq_uid(sequence)
        elif input_type == 'custom':
            frag_out = self.__get_vis_seq_custom(sequence)
        else:
            # previously this fell through to an UnboundLocalError on frag_out
            raise ValueError(
                "input_type must be 'uniprot' or 'custom', got {!r}".format(input_type))
        self.title = frag_out['title']
        self.frag_seq = frag_out['frag_seq']
        self.seq_length = len(frag_out['frag_seq'])
        # 1-based, inclusive residue numbers of the fragment (used for labels)
        self.start_mod = frag_out['start_mod']
        self.end_mod = frag_out['end_mod']
        # get plot label
        self.plot_label = self.__get_plot_label(self.frag_seq, self.start_mod, self.end_mod)
        # encode
        self.encode_seq = seqprocess.seq_process_pipe([self.frag_seq])
        # get atten
        self.atten = self.__get_atten_from_decorator(self.encode_seq)
        # post process 1: get last cls atten mean
        self.last_cls_atten_mean = self.__get_last_cls_atten_mean(self.atten)
        # post process 2: cut cls
        self.last_cls_atten_mean_cut = self.__cut_cls(self.last_cls_atten_mean, self.seq_length)
        # post process 3: normalize
        self.last_cls_atten_mean_cut_normalize = self.__normalize_cls(self.last_cls_atten_mean_cut)
        # post process 4: feature sep
        sep_out = self.__get_feature_sep(self.frag_seq, self.last_cls_atten_mean_cut_normalize)
        self.sep_cls_for_plot = sep_out['sep_cls_for_plot']
        self.sep_feature_label = sep_out['sep_feature_label']

    ########### get seq ###########
    def __get_vis_seq_custom(self, custom_sequence):
        '''Wrap a user-supplied sequence as a fragment covering residues 1..len.'''
        start_mod = 1
        end_mod = len(custom_sequence)
        return {
            "title": '{}, {}~{}'.format('custom', start_mod, end_mod),
            "frag_seq": custom_sequence,
            "start_mod": start_mod,
            "end_mod": end_mod
        }

    def __get_vis_seq_uid(self, uniprot_id):
        '''Fetch a UniProt entry and select one disorder region as the fragment.

        Returns the same dict layout as ``__get_vis_seq_custom``.
        '''
        # retrieve the sequence and PONDR order/disorder annotation
        seq_info = self.__get_sequence_online(uniprot_id, "VSL2", pondr_driver)
        gene_name = seq_info['gene_name']
        od_ident = seq_info['od_ident']
        protein_sequence = seq_info['protein_sequence']
        # od_ident length filter (thresholds: 40 for disorder, 10 for order)
        od_ident = seqfilter.length_filter_by_od_ident(od_ident, disorder_filter_length=40, order_filter_length=10)
        od_index = seqfilter.get_od_index(od_ident)['disorder_region']
        if len(od_index) == 0:
            raise NoDisorderRegion("This protein does not have disorder region")
        elif len(od_index) == 1:
            print("Only 1 disorder region {}, automatically use that".format(od_index[0]))
            od_index_i = 0
        else:
            choose_hint = ''.join(
                "region {}: {}".format(index, element) + "\n" + ' '
                for index, element in enumerate(od_index))
            od_index_i = self.__ask_index(
                "Please choose disorder region: \n {}".format(choose_hint), len(od_index))
        # region bounds are used as slice indices: 0-based start, exclusive end
        start = od_index[od_index_i]['start']
        end = od_index[od_index_i]['end']
        max_len = 512  # presumably the encoder's maximum input length -- TODO confirm
        if (end - start) > max_len:
            print("sequence length longer than 512: please specify the start and end")
            start, end = self.__ask_range(len(protein_sequence), max_len)
        frag_seq = protein_sequence[start:end]
        # convert to 1-based, inclusive residue numbering for title/labels
        start_mod = start + 1
        end_mod = end
        return {
            "title": '{}, {}, {}~{}'.format(uniprot_id, gene_name, start_mod, end_mod),
            "frag_seq": frag_seq,
            "start_mod": start_mod,
            "end_mod": end_mod
        }

    def __ask_index(self, prompt, n_choices):
        '''Prompt until the user enters an integer in ``[0, n_choices)``.'''
        while True:
            answer = input(prompt)
            try:
                choice = int(answer)
            except ValueError:
                print("Please enter an integer.")
                continue
            if 0 <= choice < n_choices:
                return choice
            print("Please enter a number between 0 and {}.".format(n_choices - 1))

    def __ask_range(self, seq_len, max_len):
        '''Prompt for a 0-based start / exclusive end slice of the protein.

        Re-prompts until ``0 <= start < end <= seq_len`` and
        ``end - start <= max_len``, then returns ``(start, end)``.
        '''
        while True:
            try:
                start = int(input("start: \n"))
                end = int(input("end: \n"))
            except ValueError:
                print("Please enter integers.")
                continue
            if 0 <= start < end <= seq_len and end - start <= max_len:
                return start, end
            print("Invalid range: need 0 <= start < end <= {} and end - start <= {}.".format(
                seq_len, max_len))

    def __get_sequence_online(self, uniprot_id, od_ident_algorithm, pondr_driver):
        '''Download a UniProt sequence and run PONDR on it.

        Returns a dict with keys uniprot_id, gene_name, protein_sequence and
        od_ident (the order/disorder annotation from ``pondr_driver``).
        '''
        # get sequence by uniprot id
        a_protein = EbiAPI(uniprot_id)
        # get sequence and pass it through the project's seq_aa_check helper
        protein_sequence = a_protein.protein_sequence
        protein_sequence = seq_aa_check(protein_sequence)
        gene_name = a_protein.gene_name
        print('protein: {}, gene name: {}, length: {}'.format(uniprot_id, gene_name, len(protein_sequence)))
        # use pondr to get od_ident
        print("sending to PONDR by algorithm {}...".format(od_ident_algorithm))
        pondr_driver.cut(protein_sequence, protein_name='aa', algorithm=od_ident_algorithm)
        od_ident = pondr_driver.get_od_ident()
        return {
            "uniprot_id": uniprot_id,
            "gene_name": gene_name,
            "protein_sequence": protein_sequence,
            "od_ident": od_ident
        }
    ########### get seq ###########

    # label plot
    def __get_plot_label(self, frag_seq, start_mod, end_mod):
        '''Return x-axis labels "<residue number>_<amino acid>" for each residue.'''
        return ["{}_{}".format(position, residue)
                for position, residue in enumerate(frag_seq, start=start_mod)]

    # get atten
    def __get_atten_from_decorator(self, encoded_seq):
        '''Run the encoder once and collect the attention maps cached by get_local.

        Returns a torch tensor laid out as (layer, head, length, length)
        according to the reshape below.
        '''
        get_local.clear()
        # inference only: no autograd graph is needed
        with torch.no_grad():
            base_encoder(encoded_seq, return_all_atten=True)
        atten = get_local.cache
        atten = np.stack(atten['scaled_dot_product_attention']).squeeze()
        # NOTE(review): get_local presumably appends one cache entry per call of
        # scaled_dot_product_attention. If that call happens once per layer, the
        # stack is layer-major and this reshape reinterprets memory instead of
        # swapping axes. Confirm the call order against the model definition in
        # 2_model_training.ipynb.
        atten = atten.reshape([num_heads, depth, atten.shape[-2], atten.shape[-2]])  # head, layer, length, length
        atten = np.moveaxis(atten, [0, 1], [1, 0])  # layer, head, length, length
        return torch.tensor(atten)

    ########### post process ###########
    # get last cls atten
    def __get_last_cls_atten_mean(self, atten):
        '''Average over heads the last layer's attention from position 0 to positions 1..'''
        last_cls_atten = atten[-1, :, 0, 1:]
        last_cls_atten_mean = last_cls_atten.mean(dim=0, keepdim=False)
        return last_cls_atten_mean.detach().numpy()

    # cut cls
    def __cut_cls(self, last_cls_atten_mean, seq_length):
        '''Keep only the first ``seq_length`` positions (drops any trailing ones).'''
        return last_cls_atten_mean[:seq_length]

    # normalize
    def __normalize_cls(self, last_cls_atten_mean):
        '''Min-max normalise the attention vector to [0, 1].'''
        return minmaxnormalize(last_cls_atten_mean)  # change normalize way

    # feature sep
    def __get_feature_sep(self, frag_seq, last_cls_atten_mean):
        '''Split per-residue attention into one row per amino-acid feature group.

        Each row keeps the attention of residues in that group and is NaN
        elsewhere. A final 'All features' row holds the unmasked values.
        Returns ``{"sep_cls_for_plot": 2-D array, "sep_feature_label": list}``.
        '''
        # (row label, residues in the group), in plotting order
        feature_groups = [
            ("Hydrophobic (A I L M P V)", {'L', 'P', 'A', 'V', 'M', 'I'}),
            ('C', {'C'}),
            ('G', {'G'}),
            ('S T', {'S', 'T'}),
            ('Prion like (Q N)', {'Q', 'N'}),
            ('Negative charge (D E)', {'D', 'E'}),
            ('Positive charge (R K H)', {'R', 'K', 'H'}),
            ('Aromatic (W Y F)', {'W', 'Y', 'F'}),
        ]
        residues = list(frag_seq)
        rows = [
            np.where([aa in members for aa in residues], last_cls_atten_mean, np.nan)
            for _, members in feature_groups
        ]
        rows.append(last_cls_atten_mean)
        return {
            "sep_cls_for_plot": np.stack(rows),
            "sep_feature_label": [label for label, _ in feature_groups] + ['All features'],
        }
    ########### post process ###########
# ## 2.2 Get attention map
# In[ ]:
# usage case1: directly input custom sequence
# seqinfo = SeqEncodeForPlot("custom","REPNQAFGSGNNSYSGSNSGAAIGWGSASNAGSGSGFNGGFGSSMDSKSSGWGM")
# usage case2: enter a UniProt entry ID; the sequence is sent to PONDR (VSL2 predictor)
# to find disorder regions that meet the length criterion (>= 40 residues)
# seqinfo = SeqEncodeForPlot("uniprot", "Q13148")
# # 3. Plot feature maps
# In[ ]:
def plot(seqinfo):
    '''Draw the per-feature attention heatmap for a SeqEncodeForPlot result.

    Layout of the heatmap:

    - Rows are ``seqinfo.sep_feature_label`` (amino-acid feature groups).
    - Columns are ``seqinfo.plot_label`` ("<residue number>_<amino acid>").
    - Cell colours are ``seqinfo.sep_cls_for_plot``, drawn on a fixed 0..1
      colour range.

    Dashed grey lines separate adjacent rows.

    Parameters
    ----------
    seqinfo : SeqEncodeForPlot
        Any object exposing ``title``, ``seq_length``, ``plot_label``,
        ``sep_feature_label`` and ``sep_cls_for_plot``.

    Returns
    -------
    plotly.graph_objects.Figure
    '''
    # light grey for the lowest values, then yellow -> orange -> red
    custom_color_scale = ['rgb(220,220,220)', 'rgb(255,243,59)', 'rgb(253,199,12)',
                          'rgb(243,144,63)', 'rgb(237,104,60)', 'rgb(233,62,58)']
    fig = go.Figure(
        data=[go.Heatmap(
            # Plotly hover templates break lines with <br> (a raw newline inside
            # the literal was a SyntaxError). The y value is the feature-group
            # row label, not an attention head.
            hovertemplate='feature: %{y}<br>aa: %{x}<br>value: %{z}',
            z=seqinfo.sep_cls_for_plot,
            x=seqinfo.plot_label,
            ygap=1,
            y=seqinfo.sep_feature_label,
            colorscale=custom_color_scale,
            zmin=0, zmax=1,  # fixed colour range
            showscale=False,  # no colour bar
        )],
        # transparent background for HTML export; for a static image use
        # Layout(paper_bgcolor='rgba(255,255,255,1)', plot_bgcolor='rgba(255,255,255,1)')
        layout=Layout(paper_bgcolor='rgba(0,0,0,0)', plot_bgcolor='rgba(0,0,0,0)'),
    )
    fig.update_layout(
        title=seqinfo.title,
        width=seqinfo.seq_length * 10,  # 10 px per residue column
        height=450,
        yaxis=dict(showticklabels=True, tickfont=dict(size=12, color='black')),
        xaxis=dict(showticklabels=True, tickmode='linear'),  # label every residue
        font=dict(size=8),
        hovermode='x unified',
    )
    # Dashed separators between adjacent rows. Categories sit at integer
    # positions, so boundaries are at k + 0.5 and the x extent is [-0.5, n - 0.5].
    n_rows = len(seqinfo.sep_feature_label)
    n_cols = len(seqinfo.plot_label)
    fig.update_layout(shapes=[
        dict(type='line',
             yref='y', y0=row + 0.5, y1=row + 0.5,
             xref='x', x0=-0.5, x1=n_cols - 0.5,
             line=dict(color='rgb(200, 200, 200)', width=0.5, dash="dash"))
        for row in range(n_rows - 1)
    ])
    return fig
# In[ ]:
# plot(seqinfo)
# In[ ]: