import collections
import os
import json
import logging
import pandas as pd
import numpy as np
import scipy
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.models import load_model
import statsmodels.api as sm
from sqlalchemy import create_engine
# When the kernel is started from a directory whose name ends in "notebook",
# move one level up before anything else runs, so that the `rna_learn` imports
# and the cwd-relative paths used further down resolve from the parent directory.
if os.getcwd().endswith('notebook'):
    os.chdir('..')
from rna_learn.alphabet import ALPHABET_DNA
from rna_learn.load_sequences import (
load_growth_temperatures,
compute_inverse_effective_sample,
assign_weight_to_batch_values,
SpeciesSequence,
load_batch_dataframe,
)
from rna_learn.transform import sequence_embedding, normalize, denormalize
from rna_learn.model import conv1d_densenet_regression_model, compile_regression_model, DenormalizedMAE
from rna_learn.int_grads import (
integrated_gradients_for_binary_features,
)
from rna_learn.fourier import compute_autocorr_fft
# Plot styling shared by the figures in this notebook.
sns.set(palette='colorblind', font_scale=1.3)
palette = sns.color_palette()
# Timestamped, INFO-level logging.
logging.basicConfig(level=logging.INFO, format="%(asctime)s (%(levelname)s) %(message)s")
# SQLite database queried below (tables: sequences, species_traits, train_test_split);
# the path is resolved relative to the current working directory.
db_path = os.path.join(os.getcwd(), 'data/condensed_traits/db/seq.db')
engine = create_engine(f'sqlite+pysqlite:///{db_path}')
# Artifacts of the training run under analysis, all under saved_models/<run_id>/.
run_id = 'run_yb64o'
model_path = os.path.join(os.getcwd(), f'saved_models/{run_id}/model.h5')
metadata_path = os.path.join(os.getcwd(), f'saved_models/{run_id}/metadata.json')
validation_csv = os.path.join(os.getcwd(), f'saved_models/{run_id}/validation.csv')
# Validation results of that run, indexed by species taxonomy id.
val_df = pd.read_csv(validation_csv).set_index('species_taxid')
val_df.head(2)
growth_tmp_actual | growth_tmp_prediction | growth_tmp_std | |
---|---|---|---|
species_taxid | |||
7 | 30.00 | 36.67 | 15.47 |
14 | 74.15 | 62.08 | 24.86 |
# Hyper-parameters recorded for the run, stored as JSON next to the weights.
with open(metadata_path) as f:
    metadata = json.load(f)

# Rebuild the architecture from those hyper-parameters. Every keyword except
# `alphabet_size` is forwarded under the same name it has in the metadata file.
model = conv1d_densenet_regression_model(
    alphabet_size=len(metadata['alphabet']),
    **{
        name: metadata[name]
        for name in (
            'growth_rate',
            'n_layers',
            'kernel_sizes',
            'dilation_rates',
            'l2_reg',
            'dropout',
        )
    },
)
model.summary()
Model: "functional_1" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== sequence (InputLayer) [(None, None, 4)] 0 __________________________________________________________________________________________________ conv_1 (Conv1D) (None, None, 15) 195 sequence[0][0] __________________________________________________________________________________________________ concat_1 (Concatenate) (None, None, 19) 0 sequence[0][0] conv_1[0][0] __________________________________________________________________________________________________ conv_2 (Conv1D) (None, None, 15) 1440 concat_1[0][0] __________________________________________________________________________________________________ concat_2 (Concatenate) (None, None, 34) 0 concat_1[0][0] conv_2[0][0] __________________________________________________________________________________________________ conv_3 (Conv1D) (None, None, 15) 2565 concat_2[0][0] __________________________________________________________________________________________________ concat_3 (Concatenate) (None, None, 49) 0 concat_2[0][0] conv_3[0][0] __________________________________________________________________________________________________ conv_4 (Conv1D) (None, None, 15) 3690 concat_3[0][0] __________________________________________________________________________________________________ concat_4 (Concatenate) (None, None, 64) 0 concat_3[0][0] conv_4[0][0] __________________________________________________________________________________________________ conv_5 (Conv1D) (None, None, 15) 4815 concat_4[0][0] __________________________________________________________________________________________________ concat_5 (Concatenate) (None, None, 79) 0 concat_4[0][0] conv_5[0][0] 
__________________________________________________________________________________________________ conv_6 (Conv1D) (None, None, 15) 5940 concat_5[0][0] __________________________________________________________________________________________________ concat_6 (Concatenate) (None, None, 94) 0 concat_5[0][0] conv_6[0][0] __________________________________________________________________________________________________ conv_7 (Conv1D) (None, None, 15) 7065 concat_6[0][0] __________________________________________________________________________________________________ concat_7 (Concatenate) (None, None, 109) 0 concat_6[0][0] conv_7[0][0] __________________________________________________________________________________________________ conv_8 (Conv1D) (None, None, 15) 8190 concat_7[0][0] __________________________________________________________________________________________________ concat_8 (Concatenate) (None, None, 124) 0 concat_7[0][0] conv_8[0][0] __________________________________________________________________________________________________ conv_9 (Conv1D) (None, None, 15) 9315 concat_8[0][0] __________________________________________________________________________________________________ concat_9 (Concatenate) (None, None, 139) 0 concat_8[0][0] conv_9[0][0] __________________________________________________________________________________________________ conv_10 (Conv1D) (None, None, 15) 10440 concat_9[0][0] __________________________________________________________________________________________________ tf_op_layer_NotEqual (TensorFlo [(None, None, 4)] 0 sequence[0][0] __________________________________________________________________________________________________ concat_10 (Concatenate) (None, None, 154) 0 concat_9[0][0] conv_10[0][0] __________________________________________________________________________________________________ tf_op_layer_Any (TensorFlowOpLa [(None, None)] 0 tf_op_layer_NotEqual[0][0] 
__________________________________________________________________________________________________ logits (GlobalAveragePooling1D) (None, 154) 0 concat_10[0][0] tf_op_layer_Any[0][0] __________________________________________________________________________________________________ dropout (Dropout) (None, 154) 0 logits[0][0] __________________________________________________________________________________________________ mean_and_std (Dense) (None, 2) 310 dropout[0][0] __________________________________________________________________________________________________ independent_normal (Independent ((None, 1), (None, 1 0 mean_and_std[0][0] ================================================================================================== Total params: 53,965 Trainable params: 53,965 Non-trainable params: 0 __________________________________________________________________________________________________
# Restore the trained weights into the freshly built architecture.
model.load_weights(model_path)
# `temperatures`, `mean` and `std` are passed unchanged to SpeciesSequence further down.
# NOTE(review): presumably `mean`/`std` are the normalisation statistics of the
# growth temperatures — confirm against rna_learn.load_sequences.
temperatures, mean, std = load_growth_temperatures(engine)
# Traits of every species flagged as belonging to the test set.
species_traits_query = """
select * from species_traits where species_taxid in (
select species_taxid from train_test_split
where in_test_set = 1
)
"""
species_traits = pd.read_sql(species_traits_query, engine)
species_traits.head(2)
species_taxid | species | genus | family | order | class | phylum | superkingdom | gram_stain | metabolism | ... | genome_size.stdev | gc_content.stdev | coding_genes.stdev | optimum_tmp.stdev | optimum_ph.stdev | growth_tmp.stdev | rRNA16S_genes.stdev | tRNA_genes.stdev | data_source | ref_id | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 7 | Azorhizobium caulinodans | Azorhizobium | Xanthobacteraceae | Rhizobiales | Alphaproteobacteria | Proteobacteria | Bacteria | negative | aerobic | ... | 1.000 | NaN | 7.572 | NaN | NaN | NaN | 0.0 | 0.0 | engqvist, genbank, gold, jemma-refseq, kegg, p... | 705, 6102, 13521, 13643, 15838, 19908, 19956 |
1 | 14 | Dictyoglomus thermophilum | Dictyoglomus | Dictyoglomaceae | Dictyoglomales | Dictyoglomia | Dictyoglomi | Bacteria | negative | anaerobic | ... | 1.342 | 0.0 | 31.086 | 0.0 | NaN | 5.869 | 0.0 | 0.0 | corkrey, engqvist, genbank, gold, jemma-refseq... | 436, 705, 6230, 13521, 14083, 16003, 19891, 19... |
2 rows × 79 columns
# Peek at test-set species whose recorded growth_tmp is exactly 37.
species_traits[species_traits['growth_tmp'] == 37].head(2)
species_taxid | species | genus | family | order | class | phylum | superkingdom | gram_stain | metabolism | ... | genome_size.stdev | gc_content.stdev | coding_genes.stdev | optimum_tmp.stdev | optimum_ph.stdev | growth_tmp.stdev | rRNA16S_genes.stdev | tRNA_genes.stdev | data_source | ref_id | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
9 | 167 | Treponema succinifaciens | Treponema | Spirochaetaceae | Spirochaetales | Spirochaetia | Spirochaetes | Bacteria | negative | anaerobic | ... | 2.236 | 0.284 | 37.873 | NaN | NaN | NaN | 0.0 | 0.000 | engqvist, faprotax, genbank, gold, jemma-refse... | 705, 999, 6514, 13521, 15182, 16567, 19891, 19... |
10 | 196 | Campylobacter fetus | Campylobacter | Campylobacteraceae | Campylobacterales | Epsilonproteobacteria | Proteobacteria | Bacteria | negative | microaerophilic | ... | 102299.821 | 0.766 | 124.603 | 0.0 | NaN | NaN | 0.0 | 1.477 | bacdive-microa, engqvist, genbank, gold, jemma... | 151, 705, 5660, 9060, 9205, 9820, 9954, 10274,... |
2 rows × 79 columns
# All sequences of a single species, joined with that species' name and growth
# temperature; the taxonomy id is bound as a query parameter.
query = """
select t.species, s.sequence, s.length, t.growth_tmp
from sequences as s
inner join species_traits as t
on t.species_taxid = s.species_taxid
where s.species_taxid = ?
"""
# Species studied in the rest of the notebook.
species_taxid = 167
species_df = pd.read_sql(query, engine, params=(species_taxid,))
# Name and growth temperature are taken from the first row of the result.
species = species_df['species'].iloc[0]
actual_tmp = species_df['growth_tmp'].iloc[0]
species, actual_tmp
('Treponema succinifaciens', 37.0)
# Validation record (actual, predicted, std) stored for the selected species.
val_df.loc[species_taxid]
growth_tmp_actual 37.00 growth_tmp_prediction 42.41 growth_tmp_std 20.00 Name: 167, dtype: float64
# Number of sequence rows returned for the selected species.
len(species_df)
3820
# Upper bound on the sequence length handed to SpeciesSequence.
# NOTE(review): presumably kept a multiple of 3 so that parts of longer
# sequences stay codon-aligned (the per-codon analysis asserts this) — confirm.
max_sequence_length = 9999 # divisible by 3
# Batched view over the sequences of the selected species, seeded with the
# seed recorded in the run's metadata.
species_seq = SpeciesSequence(
engine,
species_taxid=species_taxid,
batch_size=64,
temperatures=temperatures,
mean=mean,
std=std,
alphabet=ALPHABET_DNA,
max_sequence_length=max_sequence_length,
random_seed=metadata['seed'],
)
def load_sequence_row(rowid):
    """Fetch one record of the ``sequences`` table by its SQLite ``rowid``.

    ``rowid`` is converted with ``int()`` before being bound as the query
    parameter. The first row of the result is returned as a pandas Series;
    ``.iloc[0]`` raises ``IndexError`` when the query yields no row.

    Uses the module-level ``engine``.
    """
    q = """
select * from sequences where rowid = ?
"""
    matching_rows = pd.read_sql(q, engine, params=(int(rowid),))
    return matching_rows.iloc[0]
# Take the first batch of the species and attribute its first ten sequences.
batch_x, batch_y, _ = species_seq[0]
inputs = batch_x[0:10]
# All-ones reference input with the shape of a single encoded sequence.
baseline = np.ones(inputs[0].shape, dtype='float32')
# The target of the first element is used for all ten inputs.
target = batch_y[0]
attributions = integrated_gradients_for_binary_features(
model,
inputs,
baseline,
target,
).numpy()
attributions.shape
(10, 999)
# Database record behind the first entry of `species_seq.rowids`; its recorded
# length is used below to trim the attribution vector.
sequence_row = load_sequence_row(species_seq.rowids[0,0])
length = sequence_row.length
sequence_row
sequence_id lcl|CP002631.1_cds_AEB15151.1_2181 species_taxid 167 sequence_type CDS chromosome_id CP002631.1 location_json [[2439987, 2440718]] strand + length 732 description lcl|CP002631.1_cds_AEB15151.1_2181 [locus_tag=... metadata_json {"protein": "uridylate kinase", "protein_id": ... sequence ATGACTACAAAAGTTTTAAGCGTAGGCGGTTCAATAATTGCGCCTG... Name: 0, dtype: object
# Raw JSON metadata string stored with the sequence.
sequence_row.metadata_json
'{"protein": "uridylate kinase", "protein_id": "AEB15151.1"}'
# Keep only the first sequence's attributions, restricted to its recorded length.
attributions = attributions[0, :length]
# Two cut-offs: the 90th and the 10th percentile of the attribution values.
percentile = 90
threshold = np.percentile(attributions, percentile)
inv_threshold = np.percentile(attributions, 100 - percentile)
# Histogram of the attribution values with both cut-offs marked in red.
_, ax = plt.subplots(1, 1, figsize=(12, 6))
ax.hist(attributions, bins=50);
ax.axvline(threshold, color='red');
ax.axvline(inv_threshold, color='red');
# One-sided view: zero everything below the upper cut-off and show what is
# left as a single-row heatmap along the sequence.
attributions_c = attributions.copy()
threshold_indices = attributions_c < threshold
attributions_c[threshold_indices] = 0
_, ax = plt.subplots(1, 1, figsize=(20, 4))
sns.heatmap(attributions_c[np.newaxis,:length], vmin=0, ax=ax, cmap='Greys');
ax.get_yaxis().set_visible(False);
# Two-sided thresholding: zero the values strictly between the two cut-offs,
# so that only the two tails keep their attribution.
attributions_ = attributions.copy()
threshold_ix = attributions_ < threshold
inv_threshold_ix = attributions_ > inv_threshold
attributions_[threshold_ix & inv_threshold_ix] = 0
# Raw versus two-sided-thresholded attributions along the sequence.
_, ax = plt.subplots(1, 1, figsize=(15, 6))
ax.plot(range(length), attributions, '-', label='Raw');
ax.plot(range(length), attributions_, '-', label='Thresholding');
ax.set_title('Attributions');
ax.legend();
def autocorr(x):
    """Return the unnormalised autocorrelation of ``x`` at strictly positive lags.

    The full correlation of ``x`` with itself has ``2 * len(x) - 1`` entries
    with the zero-lag term at index ``size // 2``. Everything up to and
    including that term is dropped, so element ``k`` of the result is the
    value at lag ``k + 1``.
    """
    full = np.correlate(x, x, mode='full')
    first_positive_lag = full.size // 2 + 1
    return full[first_positive_lag:]
# Autocorrelation of the raw and of the thresholded attributions on one axis.
_, ax = plt.subplots(1, 1, figsize=(15, 6))
attr_autocorr = autocorr(attributions)
ax.plot(range(len(attr_autocorr)), attr_autocorr);
attr_autocorr_ = autocorr(attributions_)
ax.plot(range(len(attr_autocorr_)), attr_autocorr_);
# Real-input FFT of the raw autocorrelation computed over 20 input points (n=20).
ft = np.fft.rfft(attr_autocorr, n=20, norm='ortho')
# NOTE(review): `ft_real` holds the magnitude |ft| (np.abs), not the real part,
# even though it is plotted under the label 'Real'. `ft_imag` is only used by
# the commented-out plot line.
ft_real = np.abs(ft)
ft_imag = np.imag(ft)
_, ax = plt.subplots(1, 1, figsize=(15, 6))
ax.plot(range(1, len(ft) + 1), ft_real / length, label='Real');
#ax.plot(range(1, len(ft) + 1), ft_imag / length, label='Imag');
ax.legend();
# Same transform over the whole autocorrelation (no `n`), plotted without the
# division by `length`.
ft = np.fft.rfft(attr_autocorr, norm='ortho')
ft_real = np.abs(ft)
ft_imag = np.imag(ft)
_, ax = plt.subplots(1, 1, figsize=(15, 6))
ax.plot(range(1, len(ft) + 1), ft_real, label='Real');
ax.legend();
# The 20 bins with the largest magnitude, shifted by one to match the 1-based
# x axis used in the plots above.
(np.argsort(-np.abs(ft)) + 1)[:20]
array([101, 175, 155, 156, 164, 102, 154, 71, 171, 170, 107, 172, 168, 176, 143, 77, 41, 180, 64, 108])
def compute_attributions_and_fft(species_seq, actual_tmp, n=200):
    """Compute per-sequence attributions and the FFT of their autocorrelation.

    For every batch of ``species_seq`` the integrated-gradient attributions
    are computed with the module-level ``model``. Each row is then trimmed to
    the length of the sequence part it belongs to and passed to
    ``compute_autocorr_fft``.

    Changes from the previous version:

    * ``all_attributions`` is now filled. It used to be returned empty
      because nothing was ever appended to it.
    * The two-sided thresholded copy of each attribution vector was computed
      and then thrown away. That dead computation is removed; the FFT is
      still taken on the raw attributions, as before.
    * The inner loop no longer rebinds the batch slice bounds.

    Args:
        species_seq: Indexable batch provider. ``species_seq[i]`` yields
            ``(batch_x, batch_y, _)``. It must expose ``max_sequence_length``,
            ``batch_size``, ``engine`` and ``rowids``, where each row of
            ``rowids`` unpacks into ``(rowid, part_ix)``.
        actual_tmp: Unused; kept so existing calls keep working.
        n: Only the first ``n // 2`` values of each FFT are kept.

    Returns:
        A tuple ``(all_attributions, all_ffts)`` of two lists with one entry
        per sequence part, in batch order. The first holds the trimmed
        attribution vectors, the second the truncated output of
        ``compute_autocorr_fft``.
    """
    all_attributions = []
    all_ffts = []
    max_length = species_seq.max_sequence_length
    batch_size = species_seq.batch_size
    n_batches = len(species_seq)

    for i in range(n_batches):
        if (i + 1) % 10 == 0:
            print(f'Batch {i+1} / {n_batches}')

        batch_x, batch_y, _ = species_seq[i]
        # All-ones reference input with the shape of one encoded sequence.
        baseline = np.ones(batch_x[0].shape, dtype='float32')
        # The first target of the batch is used for the whole batch.
        target = batch_y[0]
        attributions = integrated_gradients_for_binary_features(
            model,
            batch_x,
            baseline,
            target,
        ).numpy()

        batch_rowids = species_seq.rowids[i * batch_size:(i + 1) * batch_size]
        batch_df = load_batch_dataframe(species_seq.engine, batch_rowids[:, 0])

        for j, seq_attribution_padded in enumerate(attributions):
            rowid, part_ix = batch_rowids[j]
            length = int(batch_df.loc[rowid]['length'])
            if length > max_length:
                # The sequence is split into parts of at most `max_length`;
                # the last part keeps whatever remains.
                part_start = part_ix * max_length
                length = min(length - part_start, max_length)

            # Copy so the stored vector does not keep the whole padded batch
            # array alive.
            seq_attribution = seq_attribution_padded[:length].copy()
            all_attributions.append(seq_attribution)

            real_fft = compute_autocorr_fft(seq_attribution, n=None)
            all_ffts.append(real_fft[:n // 2])

    return all_attributions, all_ffts
#%%time
#all_attributions, all_ffts = compute_attributions_and_fft(species_seq, actual_tmp, n=200)
#best_freqs = [np.argmax(fft) + 1 for fft in all_ffts]
#plt.hist(best_freqs, bins=100);
#mean_fft = np.mean(all_ffts, axis=0)
#
#_, ax = plt.subplots(1, 1, figsize=(15, 6))
#ax.plot(range(1, len(mean_fft) + 1), mean_fft, 'o');
def compute_attributions_per_codon(species_seq, actual_tmp):
    """Group integrated-gradient attributions by the codon they fall on.

    Every batch of ``species_seq`` is attributed with the module-level
    ``model``. Rows whose sequence type is not ``'CDS'`` are skipped. For the
    others, the attribution vector is trimmed to the length of the sequence
    part it covers and cut into consecutive triplets. Each triplet is filed
    under the three-letter codon found at the same position.

    Args:
        species_seq: Indexable batch provider. ``species_seq[i]`` yields
            ``(batch_x, batch_y, _)``. It must expose ``max_sequence_length``,
            ``batch_size``, ``engine`` and ``rowids``, where each row of
            ``rowids`` unpacks into ``(rowid, part_ix)``.
        actual_tmp: Not used by the computation.

    Returns:
        A plain ``dict`` mapping each codon string to an array built from the
        list of its attribution triplets (one row per occurrence).

    Raises:
        AssertionError: if a sequence part's length is not a multiple of 3 or
            differs from the length of its trimmed attribution vector.
    """
    codon_buckets = collections.defaultdict(list)
    part_size = species_seq.max_sequence_length

    for i in range(len(species_seq)):
        if (i + 1) % 10 == 0:
            print(f'Batch {i+1} / {len(species_seq)}')

        batch_x, batch_y, _ = species_seq[i]
        attributions = integrated_gradients_for_binary_features(
            model,
            batch_x,
            # All-ones reference input shaped like one encoded sequence.
            np.ones(batch_x[0].shape, dtype='float32'),
            # The first target of the batch is used for the whole batch.
            batch_y[0],
        ).numpy()

        first_row = i * species_seq.batch_size
        batch_rowids = species_seq.rowids[first_row:first_row + species_seq.batch_size]
        batch_df = load_batch_dataframe(species_seq.engine, batch_rowids[:, 0])

        for j, seq_attribution_padded in enumerate(attributions):
            rowid, part_ix = batch_rowids[j]
            seq_metadata = batch_df.loc[rowid]
            sequence = seq_metadata['sequence']
            length = int(seq_metadata['length'])

            # Codons are only meaningful on coding sequences.
            if seq_metadata['sequence_type'] != 'CDS':
                continue

            if length > part_size:
                # Long sequences are handled part by part; select the matching
                # slice of the nucleotide string and its effective length.
                part_start = part_ix * part_size
                part_end = part_start + part_size
                sequence = sequence[part_start:part_end]
                length = min(length, part_end) - part_start

            seq_attribution = seq_attribution_padded[:length]
            assert len(sequence) % 3 == 0
            assert len(sequence) == len(seq_attribution), f"{len(sequence)}, {len(seq_attribution)}, {int(seq_metadata['length'])}"

            for start in range(0, len(seq_attribution), 3):
                stop = start + 3
                codon_buckets[sequence[start:stop]].append(seq_attribution[start:stop])

    return {codon: np.array(triplets) for codon, triplets in codon_buckets.items()}
%%time
# Attribute every batch of the selected species and group the results by codon.
# `actual_tmp` is passed through but is not used by the function.
attributions_per_codon = compute_attributions_per_codon(species_seq, actual_tmp)
Batch 10 / 60 Batch 20 / 60 Batch 30 / 60 Batch 40 / 60 Batch 50 / 60 Batch 60 / 60 CPU times: user 2min 46s, sys: 1min 24s, total: 4min 11s Wall time: 59.5 s
# One summary tuple per codon:
# (codon, overall mean, overall std, number of occurrences,
#  mean per codon position rounded to 4 decimals).
mean_attribution_per_codon = []
for key, attr in attributions_per_codon.items():
    per_position_mean = [np.round(value, 4) for value in np.mean(attr, axis=0)]
    mean_attribution_per_codon.append(
        (key, np.mean(attr), np.std(attr), len(attr), per_position_mean)
    )

# Rank codons from the most positive to the most negative mean attribution.
sorted_mean_attribution_per_codon = sorted(
    mean_attribution_per_codon,
    key=lambda summary: summary[1],
    reverse=True,
)
sorted_mean_attribution_per_codon
[('CTG', 0.0012269845, 0.0069986414, 7173, [0.0008, 0.0005, 0.0024]), ('CAG', 0.0012047703, 0.008702724, 16926, [0.0014, 0.0011, 0.0011]), ('ATC', 0.0011405131, 0.00880471, 12313, [0.0014, 0.0011, 0.0009]), ('GAT', 0.0010472819, 0.008814094, 24036, [0.0006, 0.0015, 0.001]), ('CGC', 0.0009866252, 0.007033925, 7872, [0.001, 0.002, -1e-04]), ('TGC', 0.00093589147, 0.0072951447, 8706, [0.0003, 0.0027, -0.0002]), ('TCG', 0.00085053814, 0.00638489, 5431, [0.0005, 0.0011, 0.0009]), ('CGA', 0.0007245949, 0.006480547, 1840, [0.0008, 0.0002, 0.0011]), ('TTC', 0.0007185005, 0.008449351, 12764, [0.001, 0.0014, -0.0002]), ('TCA', 0.0006722864, 0.008287686, 16388, [0.0003, 0.0015, 0.0002]), ('AAT', 0.000666689, 0.0073139067, 27536, [1e-04, 0.0011, 0.0008]), ('GCG', 0.0005815432, 0.0069153737, 12838, [0.001, 0.0013, -0.0006]), ('AGT', 0.0005803461, 0.007288064, 5464, [1e-04, 0.0009, 0.0007]), ('CAT', 0.0005356391, 0.007231406, 7254, [0.0006, 0.0019, -0.0009]), ('ATG', 0.00051723054, 0.007417146, 18833, [0.0007, 0.0007, 1e-04]), ('CGT', 0.0005130675, 0.010666447, 5436, [0.0003, 0.0014, -1e-04]), ('GGT', 0.0004950146, 0.007122442, 7676, [-1e-04, 0.0007, 0.0009]), ('CAC', 0.00041939082, 0.0102369245, 4954, [0.0011, 0.0022, -0.002]), ('CAA', 0.00041934755, 0.0063995505, 8200, [0.0005, 1e-04, 0.0006]), ('ACC', 0.0003379698, 0.01319075, 3656, [0.0012, 0.0006, -0.0008]), ('GCA', 0.00030993277, 0.008544098, 21553, [-0.0004, 0.002, -0.0007]), ('AGC', 0.0002998688, 0.0078124106, 12850, [-0.0012, 0.0007, 0.0014]), ('GAC', 0.00028222834, 0.008492842, 21074, [0.0002, -0.0006, 0.0012]), ('ACT', 0.00028091006, 0.0070561096, 13660, [0.0009, 0.0005, -0.0006]), ('GGC', 0.00026411083, 0.007373901, 13480, [-0.001, 0.0005, 0.0013]), ('ACA', 0.00011927956, 0.010216563, 19515, [-0.0016, 0.0033, -0.0014]), ('TGG', 5.4014123e-05, 0.007060396, 8082, [0.0013, 0.0005, -0.0016]), ('GAA', 3.5069737e-05, 0.008439695, 49059, [-0.0017, 0.0015, 0.0003]), ('TGT', -1.4546423e-05, 0.010232731, 3885, [-1e-04, 0.0016, 
-0.0015]), ('CGG', -1.7114728e-05, 0.007660584, 2725, [0.0021, -0.0005, -0.0016]), ('CCG', -2.187285e-05, 0.006419064, 9126, [-0.0007, -0.0007, 0.0013]), ('AAC', -3.443031e-05, 0.007988626, 17898, [0.0003, -0.0004, -1e-04]), ('AAA', -5.6715922e-05, 0.008365504, 51297, [-0.0009, 0.0011, -0.0003]), ('TTG', -7.4529526e-05, 0.006114404, 13150, [0.0002, -0.0003, -1e-04]), ('AAG', -0.00013385444, 0.00867859, 19895, [-1e-04, 0.0006, -0.0009]), ('TTA', -0.00015087331, 0.008606512, 9853, [0.0007, -0.0002, -0.0009]), ('ATT', -0.0001861387, 0.008588004, 35854, [0.0009, -0.0005, -0.0009]), ('TTT', -0.00019083632, 0.0074606487, 32359, [0.001, -1e-04, -0.0015]), ('ACG', -0.00028037513, 0.008027939, 4945, [-0.0004, 0.001, -0.0014]), ('GCT', -0.00038770924, 0.008479666, 15256, [0.0, -0.0002, -0.001]), ('GCC', -0.000482065, 0.008810254, 8685, [0.0004, -0.0008, -0.001]), ('GTC', -0.00050598057, 0.011955925, 5967, [0.0006, -0.0021, -0.0]), ('TCT', -0.00052484474, 0.009645975, 15786, [-0.0002, 0.0004, -0.0017]), ('TAT', -0.000538721, 0.008289555, 18963, [-0.0014, 0.0005, -0.0007]), ('TCC', -0.00056565495, 0.009113393, 6569, [1e-04, -0.0, -0.0017]), ('AGA', -0.00064059685, 0.010341094, 10428, [-0.0028, 0.0026, -0.0017]), ('CTA', -0.0010165717, 0.013644065, 3094, [0.0006, -0.0035, -1e-04]), ('GGA', -0.0010713548, 0.008167442, 27899, [-0.0027, -0.0002, -0.0004]), ('CCA', -0.0011128361, 0.008666699, 7306, [-0.001, -0.0022, -1e-04]), ('GTA', -0.0011280739, 0.009622252, 12114, [-0.0003, -0.0005, -0.0026]), ('GGG', -0.0011332964, 0.0057915323, 3044, [-0.0013, -0.0004, -0.0017]), ('GTG', -0.0011828399, 0.010704376, 6477, [-0.003, -0.0011, 0.0006]), ('CCT', -0.0012527367, 0.009854893, 9925, [-0.0012, -0.0016, -0.001]), ('CTT', -0.0012695498, 0.008606616, 35552, [-0.0004, -0.0023, -0.0011]), ('GTT', -0.0013079934, 0.009266828, 26912, [0.0002, -0.0026, -0.0015]), ('CCC', -0.0014644511, 0.008015597, 1459, [-0.0012, -0.0015, -0.0017]), ('TAC', -0.0019707228, 0.009619422, 12051, [-0.0024, -0.002, 
-0.0015]), ('GAG', -0.0020935615, 0.012976261, 13878, [-0.0037, -0.0005, -0.0021]), ('AGG', -0.0021315727, 0.011015734, 3593, [-0.0022, -0.0005, -0.0038]), ('ATA', -0.002500575, 0.012337817, 16217, [-0.0014, -0.0014, -0.0047]), ('TAA', -0.0032592283, 0.0045169652, 1484, [-0.0018, -0.0023, -0.0058]), ('CTC', -0.0034883053, 0.012235202, 4711, [-0.0049, -0.0029, -0.0027]), ('TAG', -0.003704215, 0.005706313, 517, [-0.0012, -0.0056, -0.0043]), ('TGA', -0.004366006, 0.0078029274, 607, [-0.0011, -0.007, -0.0049])]
# Distribution, over all occurrences of the codon 'CTC', of the attribution
# averaged across the positions of each occurrence.
_, ax = plt.subplots(1, 1, figsize=(12, 6))
ax.hist(np.mean(attributions_per_codon['CTC'], axis=1), bins=50);