# Module-level helpers, created once at import time and used by SeqEncodeForPlot:
# - pondr_driver: CutPONDR instance (the original comment calls it a PONDR web
#   crawler); its cut(...) / get_od_ident() methods produce the order/disorder
#   identification of a full protein sequence.
# - seqfilter: SeqFilter instance; its length_filter_by_od_ident(...) and
#   get_od_index(...) methods filter that identification by region length and
#   extract the disorder-region start/end indices.
pondr_driver = CutPONDR()
seqfilter = SeqFilter()
# attention score normalize function
def minmaxnormalize(data):
    """Min-max scale ``data`` to the [0, 1] range.

    Args:
        data: numeric array-like (a NumPy array, or anything NumPy arithmetic
            accepts, e.g. a list of numbers).

    Returns:
        ``(data - min(data)) / (max(data) - min(data))``. When every element is
        equal the range is zero; instead of dividing 0 by 0 (an all-NaN result
        plus a RuntimeWarning) every element maps to 0.
    """
    lo = np.min(data)
    span = np.max(data) - lo
    if span == 0:
        # Constant input: (data - lo) is already all zeros. Divide by 1 instead
        # of 0 so the result keeps the float dtype the normal path produces.
        span = 1
    return (data - lo) / span
# raised by SeqEncodeForPlot when a UniProt protein offers nothing to visualize
class NoDisorderRegion(Exception):
    """Signal that no disorder region is left after length filtering."""
# attention map visualization
class SeqEncodeForPlot():
    """Encode one protein fragment and extract per-residue CLS attention for plotting.

    The fragment comes either from a UniProt accession or from a user-supplied
    sequence.
      * UniProt: the sequence is retrieved via ``EbiAPI``. Disorder regions are
        identified with PONDR VSL2, and one region is chosen, interactively
        when there are several.
      * Custom: the supplied sequence is used as-is.

    After construction the instance exposes:

    * ``title``, ``frag_seq``, ``seq_length``, ``start_mod``, ``end_mod``
    * ``plot_label``: ``"<1-based position>_<residue>"`` for every residue
    * ``encode_seq``: output of ``seqprocess.seq_process_pipe``
    * ``atten``: attention tensor captured from ``base_encoder``
    * ``last_cls_atten_mean``: the CLS-row attention of the last layer,
      averaged over heads
    * ``last_cls_atten_mean_cut``: the same values truncated to the fragment
      length
    * ``last_cls_atten_mean_cut_normalize``: the truncated values, min-max
      normalized
    * ``sep_cls_for_plot`` / ``sep_feature_label``: the normalized attention
      split into one row per residue group
    """

    def __init__(self, input_type, sequence):
        """Build the fragment, run the encoder and post-process its attention.

        Args:
            input_type: ``'uniprot'`` to treat ``sequence`` as a UniProt
                accession (e.g. ``'Q13148'``), or ``'custom'`` to use
                ``sequence`` itself as the residue string (e.g. ``'AAAAAAAA'``).
            sequence: accession or raw sequence, depending on ``input_type``.

        Raises:
            ValueError: if ``input_type`` is neither ``'uniprot'`` nor
                ``'custom'``.
            NoDisorderRegion: (``'uniprot'`` only) no disorder region remains
                after length filtering.
        """
        # get seq frag
        if input_type == 'uniprot':
            frag_out = self.__get_vis_seq_uid(sequence)
        elif input_type == 'custom':
            frag_out = self.__get_vis_seq_custom(sequence)
        else:
            # previously fell through and crashed with UnboundLocalError on frag_out
            raise ValueError(
                "input_type must be 'uniprot' or 'custom', got {!r}".format(input_type))
        self.title = frag_out['title']
        self.frag_seq = frag_out['frag_seq']
        self.seq_length = len(frag_out['frag_seq'])
        self.start_mod = frag_out['start_mod']
        self.end_mod = frag_out['end_mod']
        # get plot label
        self.plot_label = self.__get_plot_label(self.frag_seq, self.start_mod, self.end_mod)
        # encode
        self.encode_seq = seqprocess.seq_process_pipe([self.frag_seq])
        # get atten
        self.atten = self.__get_atten_from_decorator(self.encode_seq)
        # post process 1: last-layer CLS attention averaged over heads
        self.last_cls_atten_mean = self.__get_last_cls_atten_mean(self.atten)
        # post process 2: truncate to the fragment length
        self.last_cls_atten_mean_cut = self.__cut_cls(self.last_cls_atten_mean, self.seq_length)
        # post process 3: min-max normalize
        self.last_cls_atten_mean_cut_normalize = self.__normalize_cls(self.last_cls_atten_mean_cut)
        # post process 4: split by residue group
        sep_out = self.__get_feature_sep(self.frag_seq, self.last_cls_atten_mean_cut_normalize)
        self.sep_cls_for_plot = sep_out['sep_cls_for_plot']
        self.sep_feature_label = sep_out['sep_feature_label']

    # ########## get seq ##########
    def __get_vis_seq_custom(self, custom_sequence):
        """Wrap a user-supplied sequence; positions are reported as 1..len."""
        start_mod = 1
        end_mod = len(custom_sequence)
        return {
            "title": '{}, {}~{}'.format('custom', start_mod, end_mod),
            "frag_seq": custom_sequence,
            "start_mod": start_mod,
            "end_mod": end_mod
        }

    def __get_vis_seq_uid(self, uniprot_id):
        """Fetch a UniProt protein and cut out one disorder region as the fragment.

        Regions come from PONDR VSL2 output passed through
        ``seqfilter.length_filter_by_od_ident`` (disorder_filter_length=40,
        order_filter_length=10).
          * With one region it is used automatically; with several the user is
            prompted for a region index.
          * A region longer than 512 residues prompts for explicit start/end
            values.

        Raises:
            NoDisorderRegion: if no disorder region remains after filtering.
            IndexError: if the chosen region index is not one of those listed.
            ValueError: if a prompted value is not an integer.
        """
        # retrieve sequence and PONDR order/disorder identification
        seq_info = self.__get_sequence_online(uniprot_id, "VSL2", pondr_driver)
        gene_name = seq_info['gene_name']
        protein_sequence = seq_info['protein_sequence']
        # od_ident length filter
        od_ident = seqfilter.length_filter_by_od_ident(
            seq_info['od_ident'], disorder_filter_length=40, order_filter_length=10)
        od_index = seqfilter.get_od_index(od_ident)['disorder_region']
        if len(od_index) == 0:
            raise NoDisorderRegion("This protein does not have disorder region")
        elif len(od_index) == 1:
            print("Only 1 disorder region {}, automatically use that".format(od_index[0]))
            od_index_i = 0
        else:
            choose_hint = ''.join(
                "region {}: {}\n ".format(index, element)
                for index, element in enumerate(od_index))
            od_index_i = int(
                input("Please choose disorder region: \n {}".format(choose_hint)))
            # Reject indices not shown in the hint. Negative values used to
            # silently select a region counted from the end.
            if not 0 <= od_index_i < len(od_index):
                raise IndexError(
                    "region index {} out of range 0..{}".format(od_index_i, len(od_index) - 1))
        # get protein seq by chosen index
        region = od_index[od_index_i]
        start = region['start']
        end = region['end']
        if (end - start) > 512:
            print("sequence length longer than 512: please specify the start and end")
            # NOTE(review): the typed start/end are used directly as 0-based
            # slice bounds, while the title reports start + 1. A user typing
            # 1-based positions gets a fragment shifted by one; confirm the
            # intended convention.
            start = int(input("start: \n"))
            end = int(input("end: \n"))
        frag_seq = protein_sequence[start:end]
        # [start, end) is a 0-based half-open slice; report it as the 1-based
        # inclusive range start+1 .. end
        start_mod = start + 1
        end_mod = end
        return {
            "title": '{}, {}, {}~{}'.format(uniprot_id, gene_name, start_mod, end_mod),
            "frag_seq": frag_seq,
            "start_mod": start_mod,
            "end_mod": end_mod
        }

    def __get_sequence_online(self, uniprot_id, od_ident_algorithm, pondr_driver):
        """Retrieve a protein by UniProt id and run PONDR on its sequence.

        Args:
            uniprot_id: accession passed to ``EbiAPI``.
            od_ident_algorithm: PONDR algorithm name (e.g. ``"VSL2"``).
            pondr_driver: object providing ``cut(...)`` and ``get_od_ident()``.
                This parameter shadows the module-level driver of the same name.

        Returns:
            dict with keys:
              * ``uniprot_id``
              * ``gene_name``
              * ``protein_sequence`` (after ``seq_aa_check``)
              * ``od_ident`` (the driver's result)
        """
        # get sequence by uniprot id
        a_protein = EbiAPI(uniprot_id)
        protein_sequence = seq_aa_check(a_protein.protein_sequence)
        gene_name = a_protein.gene_name
        print('protein: {}, gene name: {}, length: {}'.format(uniprot_id, gene_name, len(protein_sequence)))
        # use pondr to get od_ident
        print("sending to PONDR by algorithm {}...".format(od_ident_algorithm))
        pondr_driver.cut(protein_sequence, protein_name='aa', algorithm=od_ident_algorithm)
        od_ident = pondr_driver.get_od_ident()
        return {
            "uniprot_id": uniprot_id,
            "gene_name": gene_name,
            "protein_sequence": protein_sequence,
            "od_ident": od_ident
        }
    # ########## get seq ##########

    def __get_plot_label(self, frag_seq, start_mod, end_mod):
        """Return ``"<position>_<residue>"`` labels, numbering from ``start_mod``.

        ``end_mod`` is accepted for signature symmetry but is not used.
        """
        return ["{}_{}".format(position, residue)
                for position, residue in enumerate(frag_seq, start=start_mod)]

    def __get_atten_from_decorator(self, encoded_seq):
        """Run ``base_encoder`` once and collect the attention maps it recorded.

        ``get_local`` is cleared first so that its cache only holds maps from
        this forward pass. Returns a ``torch.Tensor`` intended to be laid out
        as (layer, head, length, length).
        """
        get_local.clear()
        # only the side effect matters: the decorator fills get_local.cache
        _ = base_encoder(encoded_seq, return_all_atten=True)[1]
        atten = np.stack(get_local.cache['scaled_dot_product_attention']).squeeze()
        # NOTE(review): reshape reinterprets memory, it does not transpose. This
        # is only correct if the stacked cache is head-major (head, layer, L, L),
        # as the original comment assumed. If each cached entry is one layer's
        # (heads, L, L) map, the stack is already (layer, head, L, L) and this
        # reshape + moveaxis mixes heads and layers. Confirm against how
        # base_encoder calls scaled_dot_product_attention.
        atten = atten.reshape([num_heads, depth, atten.shape[-2], atten.shape[-1]])  # head, layer, length, length
        atten = np.moveaxis(atten, [0, 1], [1, 0])  # layer, head, length, length
        return torch.tensor(atten)

    # ########## post process ##########
    def __get_last_cls_atten_mean(self, atten):
        """Average the last layer's row-0 (CLS query) attention over heads.

        Assumes ``atten`` is a (layer, head, length, length) tensor, as built by
        ``__get_atten_from_decorator``. Column 0 (the CLS key itself) is
        dropped. Returns a 1-D NumPy array.
        """
        last_cls_atten = atten[-1, :, 0, 1:]
        return last_cls_atten.mean(dim=0, keepdim=False).detach().numpy()

    def __cut_cls(self, last_cls_atten_mean, seq_length):
        """Keep the first ``seq_length`` positions.

        Presumably this drops padding or trailing special tokens; confirm.
        """
        return last_cls_atten_mean[:seq_length]

    def __normalize_cls(self, last_cls_atten_mean):
        """Min-max normalize the attention values via ``minmaxnormalize``."""
        return minmaxnormalize(last_cls_atten_mean)

    def __get_feature_sep(self, frag_seq, last_cls_atten_mean):
        """Split per-residue attention into one row per residue group.

        Returns a dict with:
          * ``sep_cls_for_plot``: array of 9 rows, each as long as
            ``frag_seq``. Row i keeps the attention value where the residue
            belongs to group i and is NaN elsewhere. The last row is the
            unmasked attention.
          * ``sep_feature_label``: the 9 row labels, in row order.
        """
        # (label, residues) pairs kept together so labels always match row order
        feature_groups = (
            ("Hydrophobic (A I L M P V)", frozenset(['L', 'P', 'A', 'V', 'M', 'I'])),
            ('C', frozenset(['C'])),
            ('G', frozenset(['G'])),
            ('S T', frozenset(['S', 'T'])),
            ('Prion like (Q N)', frozenset(['Q', 'N'])),
            ('Negative charge (D E)', frozenset(['D', 'E'])),
            ('Positive charge (R K H)', frozenset(['R', 'K', 'H'])),
            ('Aromatic (W Y F)', frozenset(['W', 'Y', 'F'])),
        )
        residues = list(frag_seq)
        rows = []
        for _, members in feature_groups:
            mask = np.array([residue in members for residue in residues], dtype=bool)
            rows.append(np.where(mask, last_cls_atten_mean, np.nan))
        rows.append(last_cls_atten_mean)
        return {
            "sep_cls_for_plot": np.stack(rows),
            "sep_feature_label": [label for label, _ in feature_groups] + ['All features']
        }
    # ########## post process ##########