import collections
import os
import json
import logging
import pandas as pd
import numpy as np
import scipy
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.models import load_model
import statsmodels.api as sm
from statsmodels.sandbox.regression.predstd import wls_prediction_std
from sqlalchemy import create_engine
# When launched from the notebook/ directory, step up one level so the relative
# data/ and saved_models/ paths below resolve (and, presumably, so the local
# `rna_learn` package imported next is found — confirm).
_launch_dir = os.getcwd()
if _launch_dir.endswith('notebook'):
    os.chdir('..')
from rna_learn.alphabet import ALPHABET_DNA
from rna_learn.load_sequences import (
load_growth_temperatures,
compute_inverse_effective_sample,
assign_weight_to_batch_values,
SpeciesSequence,
)
from rna_learn.transform import sequence_embedding, normalize, denormalize
from rna_learn.model import conv1d_densenet_regression_model, compile_regression_model, DenormalizedMAE
# Plot styling and log format.
sns.set(palette='colorblind', font_scale=1.3)
palette = sns.color_palette()
logging.basicConfig(level=logging.INFO, format="%(asctime)s (%(levelname)s) %(message)s")

repo_root = os.getcwd()

# SQLite database holding the `sequences` and `species_traits` tables queried below.
db_path = os.path.join(repo_root, 'data/condensed_traits/db/seq.db')
engine = create_engine(f'sqlite+pysqlite:///{db_path}')

# Artifacts written by the training run under evaluation.
run_id = 'run_yb64o'
model_path = os.path.join(repo_root, f'saved_models/{run_id}/model.h5')
metadata_path = os.path.join(repo_root, f'saved_models/{run_id}/metadata.json')
validation_csv = os.path.join(repo_root, f'saved_models/{run_id}/validation.csv')

with open(metadata_path) as metadata_file:
    metadata = json.load(metadata_file)
metadata
{'run_id': 'run_yb64o', 'alphabet': ['A', 'C', 'G', 'T'], 'learning_rate': 0.0005, 'batch_size': 64, 'encoding_size': 20, 'decoder_n_hidden': 100, 'growth_rate': 15, 'n_layers': 10, 'kernel_sizes': [3, 5, 5, 5, 5, 5, 5, 5, 5, 5], 'strides': None, 'dilation_rates': None, 'l2_reg': 1e-05, 'dropout': 0.5, 'n_epochs': 10, 'max_sequence_length': 5000, 'seed': 28, 'val_loss': 1.5334341526031494}
# Rebuild the network architecture from the run's saved hyper-parameters
# (the trained weights are loaded into it further down).
model_config = dict(
    alphabet_size=len(metadata['alphabet']),
    growth_rate=metadata['growth_rate'],
    n_layers=metadata['n_layers'],
    kernel_sizes=metadata['kernel_sizes'],
    dilation_rates=metadata['dilation_rates'],
    l2_reg=metadata['l2_reg'],
    dropout=metadata['dropout'],
)
model = conv1d_densenet_regression_model(**model_config)
model.summary()
Model: "functional_1" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== sequence (InputLayer) [(None, None, 4)] 0 __________________________________________________________________________________________________ conv_1 (Conv1D) (None, None, 15) 195 sequence[0][0] __________________________________________________________________________________________________ concat_1 (Concatenate) (None, None, 19) 0 sequence[0][0] conv_1[0][0] __________________________________________________________________________________________________ conv_2 (Conv1D) (None, None, 15) 1440 concat_1[0][0] __________________________________________________________________________________________________ concat_2 (Concatenate) (None, None, 34) 0 concat_1[0][0] conv_2[0][0] __________________________________________________________________________________________________ conv_3 (Conv1D) (None, None, 15) 2565 concat_2[0][0] __________________________________________________________________________________________________ concat_3 (Concatenate) (None, None, 49) 0 concat_2[0][0] conv_3[0][0] __________________________________________________________________________________________________ conv_4 (Conv1D) (None, None, 15) 3690 concat_3[0][0] __________________________________________________________________________________________________ concat_4 (Concatenate) (None, None, 64) 0 concat_3[0][0] conv_4[0][0] __________________________________________________________________________________________________ conv_5 (Conv1D) (None, None, 15) 4815 concat_4[0][0] __________________________________________________________________________________________________ concat_5 (Concatenate) (None, None, 79) 0 concat_4[0][0] conv_5[0][0] 
__________________________________________________________________________________________________ conv_6 (Conv1D) (None, None, 15) 5940 concat_5[0][0] __________________________________________________________________________________________________ concat_6 (Concatenate) (None, None, 94) 0 concat_5[0][0] conv_6[0][0] __________________________________________________________________________________________________ conv_7 (Conv1D) (None, None, 15) 7065 concat_6[0][0] __________________________________________________________________________________________________ concat_7 (Concatenate) (None, None, 109) 0 concat_6[0][0] conv_7[0][0] __________________________________________________________________________________________________ conv_8 (Conv1D) (None, None, 15) 8190 concat_7[0][0] __________________________________________________________________________________________________ concat_8 (Concatenate) (None, None, 124) 0 concat_7[0][0] conv_8[0][0] __________________________________________________________________________________________________ conv_9 (Conv1D) (None, None, 15) 9315 concat_8[0][0] __________________________________________________________________________________________________ concat_9 (Concatenate) (None, None, 139) 0 concat_8[0][0] conv_9[0][0] __________________________________________________________________________________________________ conv_10 (Conv1D) (None, None, 15) 10440 concat_9[0][0] __________________________________________________________________________________________________ tf_op_layer_NotEqual (TensorFlo [(None, None, 4)] 0 sequence[0][0] __________________________________________________________________________________________________ concat_10 (Concatenate) (None, None, 154) 0 concat_9[0][0] conv_10[0][0] __________________________________________________________________________________________________ tf_op_layer_Any (TensorFlowOpLa [(None, None)] 0 tf_op_layer_NotEqual[0][0] 
__________________________________________________________________________________________________ logits (GlobalAveragePooling1D) (None, 154) 0 concat_10[0][0] tf_op_layer_Any[0][0] __________________________________________________________________________________________________ dropout (Dropout) (None, 154) 0 logits[0][0] __________________________________________________________________________________________________ mean_and_std (Dense) (None, 2) 310 dropout[0][0] __________________________________________________________________________________________________ independent_normal (Independent ((None, 1), (None, 1 0 mean_and_std[0][0] ================================================================================================== Total params: 53,965 Trainable params: 53,965 Non-trainable params: 0 __________________________________________________________________________________________________
# Fetch every sequence of one species together with its name and growth temperature.
species_taxid = 145261
query = """
select t.species, s.sequence, s.length, t.growth_tmp
from sequences as s
inner join species_traits as t
on t.species_taxid = s.species_taxid
where s.species_taxid = ?
"""
species_df = pd.read_sql(query, engine, params=(species_taxid,))
species_df.shape[0]
2649
# All rows share one species_taxid (see the WHERE clause), so the first row
# supplies the species name and its growth temperature.
species, actual_tmp = (species_df[column].iat[0] for column in ('species', 'growth_tmp'))
species, actual_tmp
('Methanothermobacter wolfeii', 60.0)
# Histogram of sequence lengths for this species, log-scaled counts.
_, ax = plt.subplots(figsize=(12, 6))
ax.hist(species_df['length'], bins=100, log=True)
actual_tmp = species_df['growth_tmp'].iat[0]
actual_tmp
60.0
# Growth-temperature statistics; `mean`/`std` are used below to (de)normalise temperatures.
temperatures, mean, std = load_growth_temperatures(engine)
max_sequence_length = metadata.get('max_sequence_length', 5000)

# Batched view over every sequence of the selected species.
species_seq = SpeciesSequence(
    engine,
    species_taxid=species_taxid,
    batch_size=64,
    temperatures=temperatures,
    mean=mean,
    std=std,
    alphabet=ALPHABET_DNA,
    max_sequence_length=max_sequence_length,
    random_seed=metadata['seed'],
)

# Compile the rebuilt model and restore the trained weights.
compile_regression_model(model, learning_rate=1e-4)
model.load_weights(model_path)

# Temperature grid (°C, 0.5 °C steps) and the same grid in normalised units.
temperature_range = np.arange(-35, 145, 0.5)
temperature_range_norm = normalize(temperature_range, mean, std)
%%time
# Run the model over every batch of this species' sequences and record, per sequence:
#   predictions       -- denormalised mean of the predicted distribution (°C);
#   log_probabilities -- log-density at each point of temperature_range
#                        (evaluated at the normalised grid temperature_range_norm);
#   means             -- mean temperature of that density restricted to the grid.
cursor = 0
n_sequences = len(species_seq.rowids)
n_batches = len(species_seq)
predictions = np.zeros((n_sequences, 1))
log_probabilities = np.zeros((n_sequences, len(temperature_range)))
means = []
for i in range(n_batches):
    if (i + 1) % 10 == 0:
        print(f'{i+1} / {n_batches}')
    x_batch, _, _ = species_seq[i]
    a = cursor
    b = cursor + len(x_batch)

    dist = model(x_batch)
    y_pred_norm = dist.mean().numpy()
    y_pred = denormalize(y_pred_norm, mean, std)
    predictions[a:b] = y_pred

    # NOTE(review): assumes dist.log_prob(scalar) yields one value per sequence
    # in the batch (it is assigned into a length-(b - a) column slice).
    for j, t in enumerate(temperature_range_norm):
        log_probabilities[a:b, j] = dist.log_prob(t).numpy()

    # Turn each row of log-densities into normalised grid weights. Subtracting
    # the row maximum before exp() keeps at least one weight at 1, so rows whose
    # density is tiny across the whole grid do not underflow to 0/0 = NaN.
    batch_log_probs = log_probabilities[a:b]
    weights = np.exp(batch_log_probs - batch_log_probs.max(axis=1, keepdims=True))
    weights /= weights.sum(axis=1, keepdims=True)
    means.extend(weights @ temperature_range)

    cursor = b
10 / 42 20 / 42 30 / 42 40 / 42 CPU times: user 28.7 s, sys: 22.4 s, total: 51.1 s Wall time: 15.4 s
from scipy.special import logsumexp

# Species-level distribution over the temperature grid: sum the per-sequence
# densities (logsumexp over the sequence axis, axis=0) and normalise to a PMF.
# Shifting by the maximum before exp() keeps the exponentials in range no matter
# how many sequences are summed.
log_density = logsumexp(log_probabilities, axis=0)
probs = np.exp(log_density - np.max(log_density))
probs /= np.sum(probs)

# Summary statistics of that PMF. They are deliberately NOT named `mean`/`std`:
# those globals hold the statistics from load_growth_temperatures() that
# normalize()/denormalize() rely on, and overwriting them would silently corrupt
# any re-run of the inference loop.
species_mode = temperature_range[np.argmax(probs)]
species_mean = np.average(temperature_range, weights=probs)
species_std = np.sqrt(np.average((temperature_range - species_mean) ** 2, weights=probs))

print((actual_tmp, species_mean, species_mode))
# Original run: (60.0, 63.14527124400265, 64.0)
print(species_std)
# Original run: 25.098262571010405

_, ax = plt.subplots(1, 1, figsize=(12, 6))
ax.plot(temperature_range, probs, label='PDF')
ax.axvline(species_mean, color=palette[1], label='Prediction')
ax.axvline(actual_tmp, color='green', label='Actual')
# Shade the grid points within one standard deviation of the predicted mean.
in_one_std = (
    (temperature_range >= species_mean - species_std)
    & (temperature_range <= species_mean + species_std)
)
ax.fill_between(temperature_range[in_one_std], probs[in_one_std], color=palette[0], alpha=0.2, label='Std')
ax.legend();
# Spread of the per-sequence mean temperatures, with the species' actual value in red.
_, ax = plt.subplots(figsize=(12, 6))
ax.hist(means, bins=50)
ax.axvline(actual_tmp, color='red');
# Per-species validation results of the run, keyed by species_taxid.
val_df = pd.read_csv(validation_csv, index_col='species_taxid')
val_df.head(10)
growth_tmp_actual | growth_tmp_prediction | growth_tmp_std | |
---|---|---|---|
species_taxid | |||
7 | 30.00 | 36.67 | 15.47 |
14 | 74.15 | 62.08 | 24.86 |
24 | 27.00 | 27.02 | 19.60 |
35 | 27.00 | 40.85 | 15.06 |
63 | 30.00 | 35.35 | 16.61 |
114 | 30.00 | 39.10 | 15.19 |
128 | 42.80 | 40.27 | 16.59 |
134 | 30.00 | 34.78 | 16.27 |
154 | 65.00 | 60.68 | 20.76 |
167 | 37.00 | 42.41 | 20.00 |
np.mean(val_df['growth_tmp_std'])  # average predicted std across validation species
18.493841059602648
# Mean absolute error of the per-species predictions (°C).
errors = val_df['growth_tmp_prediction'] - val_df['growth_tmp_actual']
mae = errors.abs().mean()
mae
6.790209713024283
# Root-mean-square error of the per-species predictions (°C).
squared_errors = (val_df['growth_tmp_prediction'] - val_df['growth_tmp_actual']) ** 2
rmse = np.sqrt(squared_errors.mean())
rmse
8.606030680450058
def compute_mae_per_bin(df, bins):
    """Return the mean absolute prediction error for each bin of actual temperature.

    Bins are half-open intervals ``[bins[k], bins[k + 1])`` over
    ``df['growth_tmp_actual']``.

    Args:
        df: DataFrame with ``growth_tmp_actual`` and ``growth_tmp_prediction`` columns.
        bins: increasing sequence of bin edges.

    Returns:
        List with ``len(bins) - 1`` entries; an empty bin yields NaN.
    """
    actual = df['growth_tmp_actual']
    abs_err = np.abs(df['growth_tmp_prediction'] - actual)
    return [
        np.mean(abs_err[(actual >= lo) & (actual < hi)])
        for lo, hi in zip(bins, bins[1:])
    ]
# Edges of the actual-temperature bins (°C); bins are half-open [edge_k, edge_k+1).
# NOTE(review): 79 is absent between 76 and 82, so that bin is twice as wide as
# its neighbours — confirm this is intentional.
bins = np.array([
    4, 10, 16, 19, 22, 25, 28, 31, 34, 37, 40,
    43, 46, 49, 52, 55, 58, 61, 64, 67, 70, 73,
    76, 82, 85, 88, 91, 106,
])
mae_per_bin = compute_mae_per_bin(val_df, bins)

_, ax = plt.subplots(1, 1, figsize=(12, 6))
# Bin widths vary (3 to 15 °C), so draw each bar across its own bin instead of a
# fixed 0.8-wide bar centred on the bin's left edge.
ax.bar(bins[:-1], mae_per_bin, width=np.diff(bins), align='edge', edgecolor='white')
ax.set_xlabel('Temperature °C (truth)')
ax.set_ylabel('MAE (°C)');
def plot_true_vs_prediction_per_specie(y_true, y_pred, ax=None):
    """Scatter predicted vs. true temperatures with a dashed y = x reference line.

    Args:
        y_true: true temperatures (°C), one per species.
        y_pred: predicted temperatures (°C), aligned with ``y_true``.
        ax: matplotlib axes to draw on; a new 8x8 figure is created when None.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(1, 1, figsize=(8, 8))
    palette = sns.color_palette()
    # floor/ceil (rather than int(), which truncates toward zero, plus an exclusive
    # range()) so the reference line covers non-integer extremes of y_true.
    lo = np.floor(np.min(y_true)) - 1
    hi = np.ceil(np.max(y_true)) + 1
    ax.plot([lo, hi], [lo, hi], '--', color=palette[1], linewidth=2)
    ax.scatter(y_true, y_pred, color=palette[0], alpha=0.5, label='Predictions')
    ax.set_xlabel('Temperature °C (truth)')
    ax.set_ylabel('Temperature °C (prediction)')
    return ax
def plot_model_results(val_df, ax=None):
    """Plot per-species predictions against truth with an uncertainty band.

    For each distinct true temperature the band is centred on the mean prediction
    and has half-width equal to the mean ``growth_tmp_std`` of those species.

    Args:
        val_df: DataFrame with ``growth_tmp_actual``, ``growth_tmp_prediction``
            and ``growth_tmp_std`` columns.
        ax: matplotlib axes to draw on; created by the scatter helper when None.

    Returns:
        The axes that were drawn on.
    """
    ax = plot_true_vs_prediction_per_specie(
        val_df['growth_tmp_actual'].values,
        val_df['growth_tmp_prediction'].values,
        ax=ax,
    )
    per_temperature = (
        val_df[['growth_tmp_actual', 'growth_tmp_prediction', 'growth_tmp_std']]
        .groupby('growth_tmp_actual')
        .mean()
    )
    centre = per_temperature['growth_tmp_prediction'].values
    half_width = per_temperature['growth_tmp_std'].values
    ax.fill_between(
        per_temperature.index.values,
        centre - half_width,
        centre + half_width,
        alpha=0.2,
        color=palette[1],
        label='Uncertainty',
    )
    ax.legend()
    return ax
plot_model_results(val_df=val_df);  # conv model: predictions vs truth on the validation set
def discretize_temperatures(temperatures, thresholds=(20, 45, 75)):
    """Map temperatures to integer classes delimited by ``thresholds``.

    With the default thresholds the classes are
    ``0: t < 20``, ``1: 20 <= t < 45``, ``2: 45 <= t < 75``, ``3: t >= 75``.
    NaN values fall into the last class, as they did with the original if/elif chain.

    Args:
        temperatures: 1-D sequence of temperatures.
        thresholds: increasing class boundaries; each boundary belongs to the class
            above it.

    Returns:
        ``np.int8`` array of class indices, same length as ``temperatures``.
    """
    values = np.asarray(temperatures, dtype=float)
    return np.digitize(values, thresholds).astype(np.int8)
# Overall accuracy (%) of the temperature classes on the validation set.
# y_true / y_pred are built here from val_df. `y_true` is never defined at
# notebook scope (it is a local of plot_model_results), and the global `y_pred`
# left by the inference loop holds per-sequence predictions for a single
# species, not validation predictions.
y_true = val_df['growth_tmp_actual'].values
y_pred = val_df['growth_tmp_prediction'].values
y_true_cat = discretize_temperatures(y_true)
y_pred_cat = discretize_temperatures(y_pred)
100 * np.sum(y_true_cat == y_pred_cat) / len(y_true_cat)
87.41721854304636
def accuracy_per_class(y_true_cat, y_pred_cat, n_classes=4):
    """Return the percentage of correctly predicted items for each true class.

    Args:
        y_true_cat: true class index per item.
        y_pred_cat: predicted class index per item, aligned with ``y_true_cat``.
        n_classes: classes ``0 .. n_classes - 1`` are reported (default 4).

    Returns:
        Dict mapping class index to accuracy in percent. A class with no true
        members maps to NaN; it is returned explicitly instead of dividing by zero.
    """
    y_true_cat = np.asarray(y_true_cat)
    y_pred_cat = np.asarray(y_pred_cat)
    acc_per_class = {}
    for c in range(n_classes):
        in_class = y_true_cat == c
        n_members = int(np.sum(in_class))
        if n_members == 0:
            acc_per_class[c] = float('nan')
            continue
        acc_per_class[c] = 100 * np.sum(y_pred_cat[in_class] == c) / n_members
    return acc_per_class
accuracy_per_class(y_true_cat=y_true_cat, y_pred_cat=y_pred_cat)  # per-class accuracy (%)
{0: 0.0, 1: 91.97994987468671, 2: 72.22222222222223, 3: 23.076923076923077}
common_patterns_path = os.path.join(os.getcwd(), 'data/condensed_traits/gc_content_IVYWREL_content.csv')
# Per-species GC / IVYWREL content, merged (inner by default) with the validation results.
gc_ivywrel_df = pd.read_csv(common_patterns_path).set_index('species_taxid')
common_patterns_df = pd.merge(gc_ivywrel_df, val_df, on='species_taxid')
common_patterns_df.head()
gc_content | IVYWREL_content | growth_tmp_actual | growth_tmp_prediction | growth_tmp_std | |
---|---|---|---|---|---|
species_taxid | |||||
7 | 0.674153 | 0.128590 | 30.00 | 36.67 | 15.47 |
14 | 0.338561 | 0.141845 | 74.15 | 62.08 | 24.86 |
24 | 0.448970 | 0.126567 | 27.00 | 27.02 | 19.60 |
35 | 0.707531 | 0.130083 | 27.00 | 40.85 | 15.06 |
63 | 0.637449 | 0.131033 | 30.00 | 35.35 | 16.61 |
common_patterns_df.loc[common_patterns_df['growth_tmp_actual'] < 20]  # species below 20 °C
gc_content | IVYWREL_content | growth_tmp_actual | growth_tmp_prediction | growth_tmp_std | |
---|---|---|---|---|---|
species_taxid | |||||
59600 | 0.340564 | 0.126672 | 17.0 | 38.36 | 21.62 |
80854 | 0.398872 | 0.127822 | 13.5 | 26.69 | 19.87 |
266749 | 0.354447 | 0.126771 | 16.0 | 36.78 | 20.08 |
326544 | 0.408469 | 0.127801 | 15.0 | 29.04 | 20.43 |
1835254 | 0.455350 | 0.123148 | 18.0 | 31.34 | 20.45 |
def compute_and_plot_pattern_against_ogt(y_actual, pattern_data, pattern_type, ax=None):
    """Fit growth temperature ~ pattern by OLS and plot the fit against the truth.

    The band drawn around the fit is centred on the mean fitted value for each
    distinct actual temperature, with half-width equal to the mean OLS
    prediction standard deviation for that temperature.

    Args:
        y_actual: actual growth temperatures (°C), one per species.
        pattern_data: 1-D numpy array with one pattern value per species, aligned
            with ``y_actual`` (it is indexed with ``[:, np.newaxis]``).
        pattern_type: label used in the plot title.
        ax: matplotlib axes to draw on; created when None.

    Returns:
        Tuple ``(ax, slope, intercept, fitted_values)``.
    """
    palette = sns.color_palette()

    design = sm.add_constant(pattern_data[:, np.newaxis])
    fit = sm.OLS(y_actual, design).fit()
    # Only the prediction std is used; the two interval bounds are discarded.
    pred_std, _, _ = wls_prediction_std(fit)
    intercept = fit.params[0]
    slope = fit.params[1]
    y_fit = slope * pattern_data + intercept

    ax = plot_true_vs_prediction_per_specie(y_actual, y_fit, ax=ax)

    band_df = pd.DataFrame({
        'growth_tmp_actual': y_actual,
        'growth_tmp_prediction': y_fit,
        'growth_tmp_std': pred_std,
    })
    per_temperature = band_df.groupby('growth_tmp_actual').mean()
    centre = per_temperature['growth_tmp_prediction'].values
    half_width = per_temperature['growth_tmp_std'].values
    ax.fill_between(
        per_temperature.index.values,
        centre - half_width,
        centre + half_width,
        alpha=0.2,
        color=palette[1],
        label='Uncertainty',
    )
    ax.set_title(f'{pattern_type} model')
    ax.legend()
    return ax, slope, intercept, y_fit
# Side by side: the convolutional model vs. a linear fit on IVYWREL content.
_, axes = plt.subplots(1, 2, figsize=(18, 8))
conv_ax, ivywrel_ax = axes
y_actual = common_patterns_df['growth_tmp_actual'].values

plot_model_results(val_df, ax=conv_ax)
conv_ax.set_title('1D convolutions model')

compute_and_plot_pattern_against_ogt(
    y_actual=y_actual,
    pattern_data=common_patterns_df['IVYWREL_content'].values,
    pattern_type='IVYWREL content',
    ax=ivywrel_ax,
);
attr_IVYWREL_path = os.path.join(os.getcwd(), 'data/condensed_traits/IVYWREL_avg_attributions.csv')
# Per-species averaged IVYWREL attributions, merged with the validation results.
attr_IVYWREL_df = (
    pd.read_csv(attr_IVYWREL_path)
    .set_index('species_taxid')
    .merge(val_df, on='species_taxid')
)
attr_IVYWREL_df.head()
attr_1_c1 | attr_1_c2 | attr_1_c3 | mean_1 | attr_2_c1 | attr_2_c2 | attr_2_c3 | mean_2 | attr_3_c1 | attr_3_c2 | attr_3_c3 | mean_3 | growth_tmp_actual | growth_tmp_prediction | growth_tmp_std | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
species_taxid | |||||||||||||||
7 | -0.001117 | -0.001378 | 0.001779 | -0.000239 | -0.000571 | -0.000700 | 0.000877 | -0.000131 | -0.001130 | -0.001389 | 0.001767 | -0.000251 | 30.00 | 36.67 | 15.47 |
14 | 0.000807 | -0.000448 | 0.001343 | 0.000567 | 0.000502 | -0.000187 | 0.000717 | 0.000344 | 0.000905 | -0.000411 | 0.001389 | 0.000628 | 74.15 | 62.08 | 24.86 |
24 | -0.000322 | -0.000879 | -0.000476 | -0.000559 | -0.000177 | -0.000445 | -0.000247 | -0.000290 | -0.000338 | -0.000884 | -0.000485 | -0.000569 | 27.00 | 27.02 | 19.60 |
35 | -0.001712 | -0.004290 | 0.001379 | -0.001541 | -0.000872 | -0.002155 | 0.000668 | -0.000786 | -0.001728 | -0.004300 | 0.001357 | -0.001557 | 27.00 | 40.85 | 15.06 |
63 | -0.000380 | -0.000280 | 0.000765 | 0.000035 | -0.000203 | -0.000152 | 0.000363 | 0.000003 | -0.000392 | -0.000292 | 0.000746 | 0.000021 | 30.00 | 35.35 | 16.61 |
# Compare linear fits on IVYWREL content vs. IVYWREL attribution.
# Both sources are first aligned on species_taxid. They come from two independent
# merges with val_df, so their row sets/order are not guaranteed to match, and
# pairing y_actual from one frame with pattern values from the other could
# silently mis-pair species.
aligned_df = common_patterns_df.join(attr_IVYWREL_df[['mean_2']], how='inner')
y_actual = aligned_df['growth_tmp_actual'].values

_, axes = plt.subplots(1, 2, figsize=(18, 8))
_, _, _, y_1 = compute_and_plot_pattern_against_ogt(
    y_actual=y_actual,
    pattern_data=aligned_df['IVYWREL_content'].values,
    pattern_type='IVYWREL content',
    ax=axes[0],
)
_, _, _, y_2 = compute_and_plot_pattern_against_ogt(
    y_actual=y_actual,
    pattern_data=aligned_df['mean_2'].values,
    pattern_type='IVYWREL attribution',
    ax=axes[1],
)

# Correlation of the two fitted values, restricted to species above 40 °C.
above_40 = y_actual > 40
np.corrcoef(y_1[above_40], y_2[above_40])
array([[1. , 0.55500345], [0.55500345, 1. ]])