import collections
import os
import json
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from tensorflow.keras.models import load_model
import statsmodels.api as sm
# When the kernel is started from a directory whose path ends in `notebook`,
# move one level up; the paths built below are all joined onto os.getcwd().
if os.getcwd().endswith('notebook'):
    os.chdir('..')
from rna_learn.transform import (
sequence_embedding,
normalize, denormalize,
make_dataset_balanced,
one_hot_encode_classes,
split_train_test_set,
)
from rna_learn.load import load_dataset
from rna_learn.model import conv1d_densenet_regression_model, compile_regression_model
# Plot styling applied to the figures drawn below.
sns.set(palette='colorblind', font_scale=1.3)

# Sequence alphabet and temperature class labels.
alphabet = ['A', 'T', 'G', 'C']
classes = ['psychrophilic', 'mesophilic', 'thermophilic']

# Identifier of the saved training run under evaluation.
run_id = 'run_94oi0'

# Every artefact is located relative to the current working directory.
working_dir = os.getcwd()
model_path = os.path.join(working_dir, f'saved_models_regression/{run_id}/model.h5')
hyperparameters_path = os.path.join(working_dir, f'saved_models_regression/{run_id}/metadata.json')
test_set_path = os.path.join(working_dir, 'data/dataset_test.csv')

# Hyperparameters recorded alongside the saved weights.
with open(hyperparameters_path) as metadata_file:
    metadata = json.load(metadata_file)

metadata
{'run_id': 'run_94oi0', 'alphabet': ['A', 'T', 'G', 'C'], 'classes': ['psychrophilic', 'mesophilic', 'thermophilic'], 'model_type': 'conv1d_densenet', 'n_epochs': 197, 'growth_rate': 19, 'n_layers': 5, 'kernel_sizes': [2, 3, 10, 20, 30], 'l2_reg': 0.0001, 'dropout': 0.5, 'seed': 4556, 'val_loss': 1.202218529340383, 'val_mae': 15.57646369934082}
# Rebuild the network from the stored hyperparameters, compile it, then
# restore the saved weights into it.
architecture_keys = ('growth_rate', 'n_layers', 'kernel_sizes', 'l2_reg', 'dropout')
model = conv1d_densenet_regression_model(
    alphabet_size=len(alphabet),
    **{key: metadata[key] for key in architecture_keys},
)
compile_regression_model(model, learning_rate=1e-4)
model.load_weights(model_path)
model.summary()
Model: "model" __________________________________________________________________________________________________ Layer (type) Output Shape Param # Connected to ================================================================================================== sequence (InputLayer) [(None, None, 4)] 0 __________________________________________________________________________________________________ conv_1 (Conv1D) (None, None, 19) 171 sequence[0][0] __________________________________________________________________________________________________ concat_1 (Concatenate) (None, None, 23) 0 sequence[0][0] conv_1[0][0] __________________________________________________________________________________________________ conv_2 (Conv1D) (None, None, 19) 1330 concat_1[0][0] __________________________________________________________________________________________________ concat_2 (Concatenate) (None, None, 42) 0 concat_1[0][0] conv_2[0][0] __________________________________________________________________________________________________ conv_3 (Conv1D) (None, None, 19) 7999 concat_2[0][0] __________________________________________________________________________________________________ concat_3 (Concatenate) (None, None, 61) 0 concat_2[0][0] conv_3[0][0] __________________________________________________________________________________________________ conv_4 (Conv1D) (None, None, 19) 23199 concat_3[0][0] __________________________________________________________________________________________________ concat_4 (Concatenate) (None, None, 80) 0 concat_3[0][0] conv_4[0][0] __________________________________________________________________________________________________ conv_5 (Conv1D) (None, None, 19) 45619 concat_4[0][0] __________________________________________________________________________________________________ concat_5 (Concatenate) (None, None, 99) 0 concat_4[0][0] conv_5[0][0] 
__________________________________________________________________________________________________ logits (GlobalAveragePooling1D) (None, 99) 0 concat_5[0][0] __________________________________________________________________________________________________ dropout (Dropout) (None, 99) 0 logits[0][0] __________________________________________________________________________________________________ mean_and_std (Dense) (None, 2) 200 dropout[0][0] __________________________________________________________________________________________________ independent_normal (Independent ((None, 1), (None, 1 0 mean_and_std[0][0] ================================================================================================== Total params: 78,518 Trainable params: 78,518 Non-trainable params: 0 __________________________________________________________________________________________________
# Draw the layer graph of the restored model, including tensor shapes.
tf.keras.utils.plot_model(model, show_shapes=True)
# Load the test-set CSV through the project loader and preview its first rows.
dataset_df = load_dataset(test_set_path, alphabet, secondary=False)
dataset_df.head()
specie_name | seqid | gene_name | start_inclusive | end_exclusive | length | strand | temperature | temperature_range | sequence | gc_content | ag_content | gt_content | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Micropolyspora internatus | NC_013159.1 | rnpB | 742238 | 742645 | 407 | + | 45.0 | thermophilic | CGAGTTGGCAGGGCGGCCGCGGCCGAGGGCATCGTCTCGACGTCTT... | 0.685504 | 0.562654 | 0.508600 |
1 | Listonella anguillarum | NC_015633.1 | recA | 2753532 | 2754579 | 1047 | - | 20.0 | psychrophilic | ATGGACGAAAATAAGCAGAAGGCGCTAGCCGCAGCACTGGGTCAAA... | 0.442216 | 0.540592 | 0.510029 |
2 | Comamonas badia | NZ_AXVM01000006.1 | rpsR | 29593 | 29878 | 285 | - | 28.0 | mesophilic | TTGACCATGTTCAGGAAATTCAACAAGAATGGCAAGAACGGCAAGC... | 0.568421 | 0.494737 | 0.414035 |
3 | Acetobacter orientalis | NZ_BAMX01000009.1 | tsaD | 11584 | 12799 | 1215 | - | 30.0 | mesophilic | ATGGCGGTCAGCAGCCAGTTTTCAGGCTTACCCGGCACCCCTCACA... | 0.596708 | 0.469959 | 0.476543 |
4 | Alicyclobacillus kakegawensis | NZ_BCRP01000001.1 | tsaD | 64880 | 65918 | 1038 | - | 50.0 | thermophilic | TTGCTCCTGTTGGGCATTGAGACGAGTTGCGACGAGACCGCCGCGG... | 0.680154 | 0.516378 | 0.564547 |
# (rows, columns) of the loaded test set.
dataset_df.shape
(3700, 13)
# Model inputs: the raw sequences embedded over the alphabet.
sequences = dataset_df['sequence'].values
x = sequence_embedding(sequences, alphabet)

# Regression targets: temperatures as a 1-D float64 array.
y = dataset_df['temperature'].values.astype(np.float64)

# NOTE(review): the normalisation statistics are computed on this test set
# itself; confirm whether the statistics used at training time should be
# applied here instead.
mean, std = y.mean(), y.std()
y_norm = normalize(y, mean, std)
%%time
# Evaluate the compiled model against the normalised targets
# (verbose=0 disables the progress output).
model.evaluate(x, y_norm, verbose=0)
CPU times: user 1min 42s, sys: 1.83 s, total: 1min 44s Wall time: 15.9 s
1.2118407927332697
%%time
# Call the model directly on the whole test set in one go; the returned object
# is queried for .mean() and .stddev() further down.
y_hat = model(x)
CPU times: user 1min 58s, sys: 41.5 s, total: 2min 40s Wall time: 34.7 s
# Map the predicted mean back to the temperature scale.
y_mean = denormalize(y_hat.mean().numpy(), mean, std)
# NOTE(review): `denormalize` is not visible here. If it computes
# `x * std + mean`, applying it to a standard deviation also adds the mean,
# whereas a spread should only be scaled by `std` — confirm.
y_std = denormalize(y_hat.stddev().numpy(), mean, std)
# The prediction array is 2-D (one column), unlike the 1-D target array `y`.
y_mean.shape
(3700, 1)
# `y_mean` has shape (n, 1) while `y` has shape (n,). Subtracting them directly
# broadcasts to an (n, n) matrix, so the metrics would be averaged over every
# (prediction, target) pair instead of over matching pairs. Flatten first.
y_mean_flat = np.asarray(y_mean).reshape(-1)
residuals = y_mean_flat - y

mae = np.mean(np.abs(residuals))
rmse = np.sqrt(np.mean(residuals ** 2))
print(f'Mean Absolute Error : {mae:.2f}')
print(f'Root Mean Square Error: {rmse:.2f}')
Mean Absolute Error : 15.96 Root Mean Square Error: 20.39
# Reference point: a constant predictor that always answers with the mean
# temperature of the test set.
y_baseline = np.full(len(y), np.mean(y))
baseline_residuals = y_baseline - y

mae_baseline = np.mean(np.abs(baseline_residuals))
rmse_baseline = np.sqrt(np.mean(baseline_residuals ** 2))
print(f'Baseline Mean Absolute Error : {mae_baseline:.2f}')
print(f'Baseline Root Mean Square Error: {rmse_baseline:.2f}')
Baseline Mean Absolute Error : 14.09 Baseline Root Mean Square Error: 16.77
def plot_true_vs_prediction_scatter(y_true, y_pred):
    """Scatter predictions against true temperatures.

    Every sample is drawn as a translucent point, and the mean prediction for
    each distinct true temperature is overlaid in a second colour, together
    with the identity line.

    Parameters
    ----------
    y_true : array-like, shape (n,)
        True temperatures.
    y_pred : array-like, shape (n,) or (n, 1)
        Predicted temperatures, in the same order as `y_true`.

    Returns
    -------
    (matplotlib Figure, matplotlib Axes)
    """
    # Flatten so that an (n, 1) prediction array lines up with the 1-D truth.
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()

    f, ax = plt.subplots(1, 1, figsize=(8, 8))
    palette = sns.color_palette()

    # Identity line: where a perfect prediction would fall.
    x_straight = range(0, 80)
    ax.plot(x_straight, x_straight, '--', color=palette[1], linewidth=2)

    # A single artist for all samples rather than one artist per sample.
    ax.plot(y_true, y_pred, 'o', color=palette[0], alpha=0.1, label='Actual')

    # Group predictions by true temperature and plot each average once.
    predictions_per_temp = collections.defaultdict(list)
    for y_t, y_p in zip(y_true, y_pred):
        predictions_per_temp[y_t].append(y_p)
    temperatures = sorted(predictions_per_temp)
    averages = [np.mean(predictions_per_temp[t]) for t in temperatures]
    ax.plot(temperatures, averages, 'o', color=palette[1], label='Average')

    ax.set_xlim(0, 80)
    ax.set_ylim(0, 80)
    ax.set_xlabel('Temperature °C (truth)')
    ax.set_ylabel('Temperature °C (prediction)')
    ax.legend()
    return f, ax
# Draw the truth-vs-prediction scatter and put the RMSE computed above in the title.
_, ax = plot_true_vs_prediction_scatter(y, y_mean)
ax.set_title(f'Convolutional DenseNet model RMSE: {rmse:.2f} °C');
def plot_mae_per_temperature(y_true, y_pred):
    """Plot the absolute prediction error against the true temperature.

    Every sample is drawn as a translucent point, and the mean absolute error
    for each distinct true temperature is overlaid in a second colour.

    Parameters
    ----------
    y_true : array-like, shape (n,)
        True temperatures.
    y_pred : array-like, shape (n,) or (n, 1)
        Predicted temperatures, in the same order as `y_true`.

    Returns
    -------
    (matplotlib Figure, matplotlib Axes)
    """
    # Flatten first: (n,) minus (n, 1) would otherwise broadcast to (n, n).
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    abs_error = np.abs(y_true - y_pred)

    f, ax = plt.subplots(1, 1, figsize=(12, 6))
    palette = sns.color_palette()

    # A single artist for all samples rather than one artist per sample.
    ax.plot(y_true, abs_error, 'o', color=palette[0], alpha=0.1, label='Actual')

    # Group errors by true temperature and plot each average once.
    error_per_temp = collections.defaultdict(list)
    for y_t, err in zip(y_true, abs_error):
        error_per_temp[y_t].append(err)
    temperatures = sorted(error_per_temp)
    averages = [np.mean(error_per_temp[t]) for t in temperatures]
    ax.plot(temperatures, averages, 'o', color=palette[1], label='Average')

    ax.set_xlabel('Temperature °C (truth)')
    ax.set_ylabel('Absolute error (°C)')
    ax.legend()
    return f, ax
# Plot the absolute error per true temperature and cap the y-axis at 60.
f, ax = plot_mae_per_temperature(y, y_mean)
ax.set_ylim(None, 60);
def plot_std_per_temperature(y_true, y_std):
    """Plot the average predicted standard deviation per true temperature.

    Parameters
    ----------
    y_true : array-like, shape (n,)
        True temperatures.
    y_std : array-like, shape (n,) or (n, 1)
        Predicted standard deviations, in the same order as `y_true`.

    Returns
    -------
    (matplotlib Figure, matplotlib Axes)
    """
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_std = np.asarray(y_std, dtype=np.float64).ravel()

    f, ax = plt.subplots(1, 1, figsize=(12, 6))
    palette = sns.color_palette()

    # Group the standard deviations by true temperature.
    std_per_temp = collections.defaultdict(list)
    for y_t, std_value in zip(y_true, y_std):
        std_per_temp[y_t].append(std_value)

    # Plot each average once per distinct temperature, not once per sample.
    temperatures = sorted(std_per_temp)
    averages = [np.mean(std_per_temp[t]) for t in temperatures]
    ax.plot(temperatures, averages, 'o', color=palette[0])

    ax.set_xlabel('Temperature °C (truth)')
    ax.set_ylabel('Average standard deviation (°C)')
    return f, ax
# Average predicted standard deviation per true temperature.
plot_std_per_temperature(y, y_std);
def mae_per_gene(dataset_df, y_true, y_pred):
    """Draw a one-column heatmap of the mean absolute error for each gene.

    Rows of `dataset_df` are matched to `y_true` and `y_pred` by position.

    Parameters
    ----------
    dataset_df : pandas.DataFrame
        Must contain a `gene_name` column with one row per sample.
    y_true : array-like, shape (n,)
        True temperatures.
    y_pred : array-like, shape (n,) or (n, 1)
        Predicted temperatures.

    Returns
    -------
    (matplotlib Figure, matplotlib Axes, numpy.ndarray)
        The figure, its axes, and the gene names ordered by increasing error.
    """
    # Flatten both arrays: subtracting an (n,) array from an (n, 1) array
    # broadcasts to (n, n) and would average the error over every pair of
    # samples rather than over matching samples.
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    abs_error = np.abs(y_pred - y_true)

    f, ax = plt.subplots(1, 1, figsize=(6, 12))

    genes = np.array(sorted(dataset_df['gene_name'].unique().tolist()))

    # Positional boolean masks, so a non-default DataFrame index cannot
    # misalign rows with the prediction arrays.
    results = [
        np.mean(abs_error[(dataset_df['gene_name'] == gene).values])
        for gene in genes
    ]

    sorted_idx = np.argsort(results).tolist()
    cm = pd.DataFrame(
        np.array(results)[sorted_idx],
        index=[f'{g}' for g in genes[sorted_idx]],
        columns=[''],
    )
    sns.heatmap(cm, cmap="Greys_r", annot=True, fmt='.2f', cbar=False, ax=ax)
    plt.yticks(rotation=0)
    ax.set_title('Mean absolute error (°C)')
    return f, ax, genes[sorted_idx]
# Per-gene error heatmap; keep only the gene names ordered by increasing error.
_, _, sorted_genes = mae_per_gene(dataset_df, y, y_mean)
def mae_per_temperature_band(dataset_df, y_true, y_pred):
    """Draw a one-column heatmap of the mean absolute error per temperature band.

    Bands are 5 °C wide half-open intervals [start, start + 5), with starts
    at 2, 7, ..., 72. Rows of `dataset_df` are matched to `y_true` and
    `y_pred` by position.

    Parameters
    ----------
    dataset_df : pandas.DataFrame
        Must contain a `temperature` column with one row per sample.
    y_true : array-like, shape (n,)
        True temperatures.
    y_pred : array-like, shape (n,) or (n, 1)
        Predicted temperatures.

    Returns
    -------
    (matplotlib Figure, matplotlib Axes)
    """
    # Flatten both arrays: subtracting an (n,) array from an (n, 1) array
    # broadcasts to (n, n) and would average the error over every pair of
    # samples rather than over matching samples.
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()
    abs_error = np.abs(y_pred - y_true)

    f, ax = plt.subplots(1, 1, figsize=(6, 10))

    bucket_size = 5
    bucket_starts = list(range(2, 76, bucket_size))
    temp_buckets = [f'{s}-{s + (bucket_size - 1)} °C' for s in bucket_starts]

    temperatures = dataset_df['temperature'].values
    results = []
    for bucket_start in bucket_starts:
        # Positional mask, independent of the DataFrame index labels.
        in_bucket = (
            (temperatures >= bucket_start) &
            (temperatures < bucket_start + bucket_size)
        )
        # An empty band yields NaN without a mean-of-empty-slice warning.
        results.append(np.mean(abs_error[in_bucket]) if in_bucket.any() else np.nan)

    cm = pd.DataFrame(
        np.array(results),
        index=temp_buckets,
        columns=[''],
    )
    sns.heatmap(
        cm, cmap="Greys_r", annot=True, fmt='.2f', cbar=False, ax=ax,
        linewidths=1, linecolor='#cccccc',
    )
    plt.yticks(rotation=0)
    ax.set_title('MAE per temperature band')
    # Returned for consistency with the other plotting helpers in this file.
    return f, ax
# Error heatmap per temperature band on the test set.
mae_per_temperature_band(dataset_df, y, y_mean)
# Load a second CSV (`dataset_balanced_test.csv`) used for the per-species
# analysis below, and preview its first rows.
species_dataset = os.path.join(os.getcwd(), 'data/dataset_balanced_test.csv')
species_dataset_df = load_dataset(species_dataset, alphabet, secondary=False)
species_dataset_df.head()
specie_name | seqid | gene_name | start_inclusive | end_exclusive | length | strand | temperature | temperature_range | sequence | gc_content | ag_content | gt_content | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Acetobacter pasteurianus | NC_013209.1 | recA | 43040 | 44075 | 1035 | + | 26.0 | mesophilic | ATGGTAAAAATGGATAAGGCAAAGGCTCTCGAAGGCGCGCTGGGGC... | 0.526570 | 0.549758 | 0.543961 |
1 | Acetobacter pasteurianus | NC_013209.1 | ffs | 51809 | 51907 | 98 | + | 26.0 | mesophilic | GGAAAGTCGGCAGTGGACGGATACCTTGCCAACCCGGTCAGATCCG... | 0.571429 | 0.520408 | 0.520408 |
2 | Acetobacter pasteurianus | NC_013209.1 | map | 288964 | 289765 | 801 | - | 26.0 | mesophilic | ATGGCCGGCAGAGGCGGAATTATTCTGCATACTGAAGAAGATTTTA... | 0.516854 | 0.503121 | 0.506866 |
3 | Acetobacter pasteurianus | NC_013209.1 | murA | 402403 | 403666 | 1263 | - | 26.0 | mesophilic | ATGGATCGTTTCATCATCCGGGGCGGACGCCCCCTGCACGGCGAAA... | 0.588282 | 0.501188 | 0.510689 |
4 | Acetobacter pasteurianus | NC_013209.1 | pyk | 638524 | 639985 | 1461 | + | 26.0 | mesophilic | ATGGCTGAATCCACACAACAGGGTTCGGGCGCAGAACAGGCGCAAA... | 0.607118 | 0.517454 | 0.531143 |
def predictions_per_specie(species_dataset_df, model, mean, std, alphabet, selected_genes=None):
    """Predict one temperature per species by averaging over its sequences.

    For each species, all of its sequences (optionally restricted to
    `selected_genes`) are embedded and passed to `model`; the denormalised
    predicted means are averaged into a single value. Progress is printed
    for every species.

    Parameters
    ----------
    species_dataset_df : pandas.DataFrame
        Must contain `specie_name`, `gene_name`, `temperature` and
        `sequence` columns.
    model : callable
        Called on the embedded sequences; its result must expose `.mean()`
        returning an object with `.numpy()`.
    mean, std :
        Passed to `denormalize` together with the predicted means.
    alphabet :
        Passed to `sequence_embedding`.
    selected_genes : collection, optional
        When given, only rows whose `gene_name` is in it are used.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray, numpy.ndarray)
        Species names, predicted temperatures and actual temperatures, all
        aligned. A species left with no sequence after gene filtering is
        skipped (with a printed notice) and absent from all three arrays.
    """
    species = sorted(species_dataset_df['specie_name'].unique().tolist())

    # Apply the optional gene filter once rather than once per species.
    candidates = species_dataset_df
    if selected_genes is not None:
        candidates = candidates[candidates['gene_name'].isin(selected_genes)]

    kept_species, predictions, actuals = [], [], []
    for i, specie in enumerate(species):
        print(f'{i + 1} / {len(species)}: {specie}')

        df = candidates[candidates['specie_name'] == specie]
        if df.empty:
            # Indexing [0] on an empty selection would raise IndexError.
            print(f'No sequence available for {specie}; skipping')
            continue

        # The first row's temperature is taken as the species temperature.
        y_s = df['temperature'].values.astype(np.float64)[0]

        sequences_s = df['sequence'].values
        x_s = sequence_embedding(sequences_s, alphabet)

        y_s_hat = model(x_s)
        y_s_mean = denormalize(y_s_hat.mean().numpy(), mean, std)

        kept_species.append(specie)
        predictions.append(np.mean(y_s_mean))
        actuals.append(y_s)

    return np.array(kept_species), np.array(predictions), np.array(actuals)
# One aggregated prediction per species, using every gene (no gene filter).
species, predictions, actuals = predictions_per_specie(species_dataset_df, model, mean, std, alphabet)
1 / 238: Acetobacter nitrogenifigens 2 / 238: Acetobacter pasteurianus 3 / 238: Acidaminococcus intestini 4 / 238: Acidiplasma cupricumulans 5 / 238: Acidithiobacillus caldus 6 / 238: Acidovorax avenae 7 / 238: Acidovorax citrulli 8 / 238: Acidovorax soli 9 / 238: Acinetobacter seifertii 10 / 238: Actinocatenispora sera 11 / 238: Adhaeribacter aquaticus 12 / 238: Aeromonas bestiarum 13 / 238: Agathobacter rectalis 14 / 238: Algoriphagus aquimarinus 15 / 238: Algoriphagus ratkowskyi 16 / 238: Alicyclobacillus acidocaldarius 17 / 238: Alicyclobacillus acidoterrestris 18 / 238: Alicyclobacillus herbarius 19 / 238: Aliivibrio salmonicida 20 / 238: Alkaliphilus peptidifermentans 21 / 238: Allobaculum stercoricanis 22 / 238: Alloscardovia criceti 23 / 238: Alteromonas stellipolaris 24 / 238: Anaerobacillus alkalidiazotrophicus 25 / 238: Anaerolinea thermolimosa 26 / 238: Aquisalimonas asiatica 27 / 238: Arcanobacterium haemolyticum 28 / 238: Archangium gephyra 29 / 238: Ardenticatena maritima 30 / 238: Arthrobacter agilis 31 / 238: Asaccharospora irregularis 32 / 238: Asticcacaulis biprosthecium 33 / 238: Aurantimonas coralicida 34 / 238: Aureimonas frigidaquae 35 / 238: Azotobacter chroococcum 36 / 238: Bacillus atrophaeus 37 / 238: Bacillus infantis 38 / 238: Bacillus nealsonii 39 / 238: Bacillus thuringiensis 40 / 238: Bacillus trypoxylicola 41 / 238: Bacteroides fluxus 42 / 238: Bavariicoccus seileri 43 / 238: Bifidobacterium animalis 44 / 238: Bifidobacterium stellenboschense 45 / 238: Bordetella bronchiseptica 46 / 238: Bordetella hinzii 47 / 238: Bordetella holmesii 48 / 238: Bosea thiooxidans 49 / 238: Brachybacterium paraconglomeratum 50 / 238: Brochothrix campestris 51 / 238: Brochothrix thermosphacta 52 / 238: Burkholderia ambifaria 53 / 238: Burkholderia diffusa 54 / 238: Caldicellulosiruptor acetigenus 55 / 238: Caldicellulosiruptor hydrothermalis 56 / 238: Caldicellulosiruptor kristjanssonii 57 / 238: Caldicellulosiruptor kronotskyensis 58 / 238: Caldilinea 
aerophila 59 / 238: Caloramator mitchellensis 60 / 238: Caloramator proteoclasticus 61 / 238: Campylobacter curvus 62 / 238: Campylobacter sputorum 63 / 238: Candidimonas nitroreducens 64 / 238: Catenibacterium mitsuokai 65 / 238: Cedecea neteri 66 / 238: Cellulomonas flavigena 67 / 238: Cellulophaga tyrosinoxydans 68 / 238: Chlorobaculum parvum 69 / 238: Chryseobacterium culicis 70 / 238: Chryseobacterium piscicola 71 / 238: Clostridium akagii 72 / 238: Clostridium amylolyticum 73 / 238: Clostridium cellulosi 74 / 238: Clostridium fimetarium 75 / 238: Clostridium indolis 76 / 238: Clostridium lundense 77 / 238: Clostridium puniceum 78 / 238: Clostridium saccharobutylicum 79 / 238: Clostridium septicum 80 / 238: Clostridium spiroforme 81 / 238: Clostridium tetani 82 / 238: Clostridium thermobutyricum 83 / 238: Comamonas badia 84 / 238: Coprothermobacter proteolyticus 85 / 238: Coriobacterium glomerans 86 / 238: Corynebacterium ammoniagenes 87 / 238: Corynebacterium aquilae 88 / 238: Corynebacterium camporealensis 89 / 238: Corynebacterium caspium 90 / 238: Corynebacterium tuberculostearicum 91 / 238: Cyanobium gracile 92 / 238: Defluviitoga tunisiensis 93 / 238: Deinococcus maricopensis 94 / 238: Deinococcus radiodurans 95 / 238: Desulfobacterium oleovorans 96 / 238: Desulfobulbus propionicus 97 / 238: Desulfococcus multivorans 98 / 238: Desulfoplanes formicivorans 99 / 238: Desulfovermiculus halophilus 100 / 238: Desulfovibrio bastinii 101 / 238: Desulfovibrio brasiliensis 102 / 238: Desulfovibrio frigidus 103 / 238: Desulfurella acetivorans 104 / 238: Desulfurivibrio alkaliphilus 105 / 238: Desulfurobacterium indicum 106 / 238: Dickeya zeae 107 / 238: Dietzia maris 108 / 238: Diplorickettsia massiliensis 109 / 238: Elstera litoralis 110 / 238: Enterococcus caccae 111 / 238: Enterococcus faecium 112 / 238: Enterococcus quebecensis 113 / 238: Enterococcus villorum 114 / 238: Ewingella americana 115 / 238: Exiguobacterium profundum 116 / 238: Ferrimonas sediminum 
117 / 238: Ferroplasma acidiphilum 118 / 238: Fictibacillus solisalsi 119 / 238: Flaviramulus basaltis 120 / 238: Flavobacterium antarcticum 121 / 238: Flavobacterium branchiophilum 122 / 238: Flavobacterium denitrificans 123 / 238: Flavobacterium flevense 124 / 238: Flavobacterium frigoris 125 / 238: Flavobacterium gelidilacus 126 / 238: Flavobacterium johnsoniae 127 / 238: Flavobacterium psychrophilum 128 / 238: Franconibacter helveticus 129 / 238: Fusobacterium nucleatum 130 / 238: Gaetbulibacter saemankumensis 131 / 238: Gelidibacter algens 132 / 238: Geobacillus galactosidasius 133 / 238: Geobacillus icigianus 134 / 238: Geobacillus vulcani 135 / 238: Gillisia limnaea 136 / 238: Glaciecola punicea 137 / 238: Globicatella sanguinis 138 / 238: Gracilibacillus lacisalsi 139 / 238: Granulicatella elegans 140 / 238: Granulicella rosea 141 / 238: Halobacteriovorax marinus 142 / 238: Haloferax larsenii 143 / 238: Halomonas aquamarina 144 / 238: Hugenholtzia roseola 145 / 238: Hydrogenovibrio crunogenus 146 / 238: Hymenobacter gelipurpurascens 147 / 238: Hymenobacter psychrophilus 148 / 238: Hyunsoonleella jejuensis 149 / 238: Klebsiella aerogenes 150 / 238: Kluyvera ascorbata 151 / 238: Kriegella aquimaris 152 / 238: Kyrpidia tusciae 153 / 238: Lactobacillus amylolyticus 154 / 238: Lactobacillus cacaonum 155 / 238: Lactobacillus senmaizukei 156 / 238: Lampropedia hyalina 157 / 238: Leisingera aquimarina 158 / 238: Limimonas halophila 159 / 238: Limnochorda pilosa 160 / 238: Listonella anguillarum 161 / 238: Luteibacter rhizovicinus 162 / 238: Lutispora thermophila 163 / 238: Maribacter antarcticus 164 / 238: Marinitoga piezophila 165 / 238: Marinobacter santoriniensis 166 / 238: Marinomonas polaris 167 / 238: Marinovum algicola 168 / 238: Mesoflavibacter zeaxanthinifaciens 169 / 238: Microbacterium indicum 170 / 238: Microterricola viridarii 171 / 238: Microvirga flocculans 172 / 238: Moorella glycerini 173 / 238: Morganella psychrotolerans 174 / 238: Neisseria 
animaloris 175 / 238: Nitrobacter vulgaris 176 / 238: Nonlabens ulvanivorans 177 / 238: Oceanibacterium hippocampi 178 / 238: Paenibacillus darwinianus 179 / 238: Pandoraea apista 180 / 238: Paraburkholderia graminis 181 / 238: Peptoclostridium litorale 182 / 238: Persephonella marina 183 / 238: Photobacterium kishitanii 184 / 238: Photorhabdus laumondii 185 / 238: Picrophilus oshimae 186 / 238: Planktotalea frisia 187 / 238: Planococcus antarcticus 188 / 238: Planomicrobium glaciei 189 / 238: Polaribacter irgensii 190 / 238: Polaromonas glacialis 191 / 238: Polynucleobacter wuianus 192 / 238: Proteus vulgaris 193 / 238: Providencia rustigianii 194 / 238: Pseudacidovorax intermedius 195 / 238: Pseudoalteromonas tunicata 196 / 238: Pseudothermotoga elfii 197 / 238: Psychrilyobacter atlanticus 198 / 238: Psychrobacter aquaticus 199 / 238: Psychrobacter arcticus 200 / 238: Rhodoferax fermentans 201 / 238: Roseisalinus antarcticus 202 / 238: Rubritepida flocculans 203 / 238: Sediminibacillus halophilus 204 / 238: Serratia nematodiphila 205 / 238: Shewanella algae 206 / 238: Shewanella frigidimarina 207 / 238: Simplicispira psychrophila 208 / 238: Smaragdicoccus niigatensis 209 / 238: Spirosoma panaciterrae 210 / 238: Spirosoma spitsbergense 211 / 238: Sporolituus thermophilus 212 / 238: Stenotrophomonas maltophilia 213 / 238: Sulfurihydrogenibium azorense 214 / 238: Sulfurivirga caldicuralii 215 / 238: Symbiobacterium thermophilum 216 / 238: Taylorella equigenitalis 217 / 238: Tepidanaerobacter syntrophicus 218 / 238: Terriglobus saanensis 219 / 238: Thalassospira profundimaris 220 / 238: Thermobrachium celere 221 / 238: Thermomonospora curvata 222 / 238: Thermosulfurimonas dismutans 223 / 238: Thermus filiformis 224 / 238: Thioalkalivibrio thiocyanoxidans 225 / 238: Thiohalospira halophila 226 / 238: Thiomicrospira aerophila 227 / 238: Tomitella biformata 228 / 238: Tsukamurella pseudospumae 229 / 238: Tuberibacillus calidus 230 / 238: Vibrio cholerae 231 / 238: 
Vibrio gigantis 232 / 238: Virgibacillus halodenitrificans 233 / 238: Xenorhabdus vietnamensis 234 / 238: Yersinia bercovieri 235 / 238: Yersinia enterocolitica 236 / 238: Yersinia entomophaga 237 / 238: Yersinia frederiksenii 238 / 238: Zobellia galactanivorans
def plot_true_vs_prediction_per_specie(y_true, y_pred):
    """Scatter per-species predictions against true temperatures.

    Parameters
    ----------
    y_true : array-like, shape (n,)
        True temperature of each species.
    y_pred : array-like, shape (n,) or (n, 1)
        Predicted temperature of each species, in the same order.

    Returns
    -------
    (matplotlib Figure, matplotlib Axes)
    """
    y_true = np.asarray(y_true, dtype=np.float64).ravel()
    y_pred = np.asarray(y_pred, dtype=np.float64).ravel()

    f, ax = plt.subplots(1, 1, figsize=(8, 8))
    palette = sns.color_palette()

    # Identity line: where a perfect prediction would fall.
    x_straight = range(0, 80)
    ax.plot(x_straight, x_straight, '--', color=palette[1], linewidth=2)

    # A single artist for all points rather than one artist per point.
    ax.plot(y_true, y_pred, 'o', color=palette[0])

    ax.set_xlim(0, 80)
    ax.set_ylim(0, 80)
    ax.set_xlabel('Temperature °C (truth)')
    ax.set_ylabel('Temperature °C (prediction)')
    return f, ax
# Error metrics and correlation of the per-species aggregated predictions.
species_residuals = predictions - actuals
mae_species = np.mean(np.abs(species_residuals))
rmse_species = np.sqrt(np.mean(species_residuals ** 2))
species_correlation = np.corrcoef(actuals, predictions)[0, 1]

print(f'MAE species : {mae_species:.2f}')
print(f'RMSE species: {rmse_species:.2f}')
print(f'Correlation : {species_correlation:.2f}')
MAE species : 9.83 RMSE species: 16.38 Correlation : 0.46
# Reference point: a constant predictor answering 35 °C for every species.
predictions_35 = np.full(len(actuals), 35.)
residuals_35 = predictions_35 - actuals

mae_species_35 = np.mean(np.abs(residuals_35))
rmse_species_35 = np.sqrt(np.mean(residuals_35 ** 2))
print(f'MAE 35 °C : {mae_species_35:.2f}')
print(f'RMSE 35 °C: {rmse_species_35:.2f}')
MAE 35 °C : 10.24 RMSE 35 °C: 13.18
# Per-species scatter, titled with the aggregated error metrics.
f, ax = plot_true_vs_prediction_per_specie(actuals, predictions)
title = ''.join([
    f'Predictions aggregated per specie\n\n',
    f'MAE : {mae_species:.2f} °C | ',
    f'RMSE: {rmse_species:.2f} °C',
])
ax.set_title(title);