import collections
import os
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
if os.getcwd().endswith('notebook'):
    os.chdir('..')
from rna_learn.model import (
    rnn_classification_model,
    compile_classification_model,
)
from rna_learn.transform import (
    sequence_embedding,
    make_dataset_balanced,
    one_hot_encode_classes,
    split_train_test_set,
    normalize,
    denormalize,
)
from rna_learn.load import load_rna_nucleotides_dataset
from rna_learn.vae import variational_autoencoder, compile_vae
seed = 123
np.random.seed(seed)
sns.set(palette='colorblind', font_scale=1.3)
rna = 'trna'
alphabet = ['A', 'T', 'G', 'C']
alphabet_size = len(alphabet)
metadata_path = f'data/rna_temp/tab/{rna}.tab'
sequences_folder = 'data/rna_temp/seq/'
classes = ['psychrophile', 'mesophile', 'thermophile', 'hyperthermophile']
n_classes = len(classes)
n_entries_per_class = 153
output_path = os.path.join(os.getcwd(), 'saved_models', f'seed_{seed}', f'{rna}_classification.h5')
model = rnn_classification_model(alphabet_size=alphabet_size, n_classes=n_classes, n_lstm=2)
compile_classification_model(model, learning_rate=1e-4)
model.load_weights(output_path)
model.summary()
Model: "model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= sequence (InputLayer) [(None, None, 4)] 0 _________________________________________________________________ masking (Masking) (None, None, 4) 0 _________________________________________________________________ lstm (LSTM) (None, None, 100) 42000 _________________________________________________________________ lstm_1 (LSTM) (None, 100) 80400 _________________________________________________________________ dense (Dense) (None, 100) 10100 _________________________________________________________________ dropout (Dropout) (None, 100) 0 _________________________________________________________________ dense_1 (Dense) (None, 4) 404 ================================================================= Total params: 132,904 Trainable params: 132,904 Non-trainable params: 0 _________________________________________________________________
%%time
metadata = pd.read_csv(metadata_path, delimiter='\t')
metadata['category'] = metadata['temp.cat']
sequences = load_rna_nucleotides_dataset(metadata, sequences_folder)
CPU times: user 178 ms, sys: 138 ms, total: 315 ms
Wall time: 316 ms
y_balanced, balanced_metadata = make_dataset_balanced(metadata, n_entries_per_class=n_entries_per_class)
balanced_sequences = load_rna_nucleotides_dataset(balanced_metadata, sequences_folder)
y = one_hot_encode_classes(y_balanced, classes)
x = sequence_embedding(balanced_sequences, alphabet)
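For intuition, here is a minimal sketch of the kind of one-hot embedding sequence_embedding is assumed to produce per sequence; the real function presumably also zero-pads sequences to a common length, which is what the Masking layer in the model skips over (illustration only, with a hypothetical helper):

def one_hot_sequence(seq, alphabet):
    # Map each nucleotide to a one-hot row: shape (len(seq), len(alphabet)).
    index = {c: i for i, c in enumerate(alphabet)}
    out = np.zeros((len(seq), len(alphabet)))
    for pos, nucleotide in enumerate(seq):
        out[pos, index[nucleotide]] = 1.0
    return out

one_hot_sequence('GAT', alphabet)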
x_train, y_train, x_test, y_test, train_idx, test_idx = split_train_test_set(
    x, y, test_ratio=0.2, return_indices=True)
loss, accuracy = model.evaluate(x_test, y_test, verbose=0)
print(f'Loss: {loss:.4f}')
print(f'Accuracy: {accuracy:.4f}')
Loss: 1.2486
Accuracy: 0.4098
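For context (a check added here, not part of the original run): with four balanced classes, chance accuracy is about 25%, so 41% is well above random guessing. The test-split class distribution:

test_counts = collections.Counter(int(np.argmax(el)) for el in y_test)
for idx, count in sorted(test_counts.items()):
    print(f'{classes[idx]}: {count} ({count / len(y_test):.1%})')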
y_pred = model(x_test)
labels = [np.argmax(el) for el in y_test]
predictions = [np.argmax(el) for el in y_pred]
conf = tf.math.confusion_matrix(labels, predictions).numpy()
conf
array([[17,  5,  2,  2],
       [15,  5,  7,  4],
       [13,  4,  9, 10],
       [ 3,  4,  3, 19]], dtype=int32)
df_cm = pd.DataFrame(
    conf,
    index=[f'{c}: {i}' for i, c in enumerate(classes)],
    columns=[f'{i}' for i, c in enumerate(classes)],
)
ax = sns.heatmap(df_cm, cmap="Purples");
ax.set_xlabel('Predictions');
ax.set_ylabel('Labels');
conf_sum = np.sum(conf, axis=1)
for i in range(len(classes)):
acc = 100 * conf[i, i] / conf_sum[i]
print(f'Accuracy for class {classes[i]}: {acc:.1f}%')
Accuracy for class psychrophile: 65.4%
Accuracy for class mesophile: 16.1%
Accuracy for class thermophile: 25.0%
Accuracy for class hyperthermophile: 65.5%
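A possible variant (not in the original notebook): normalising each confusion-matrix row turns counts into per-class recall, which is easier to compare when the test split is not perfectly balanced:

conf_norm = conf / conf.sum(axis=1, keepdims=True)
ax = sns.heatmap(conf_norm, cmap='Purples', annot=True, fmt='.2f',
                 xticklabels=classes, yticklabels=classes)
ax.set_xlabel('Predictions');
ax.set_ylabel('Labels');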
y_pred_t = model(x_train)
labels_t = [np.argmax(el) for el in y_train]
predictions_t = [np.argmax(el) for el in y_pred_t]
conf_t = tf.math.confusion_matrix(labels_t, predictions_t).numpy()
conf_t
array([[91, 14, 15,  7],
       [61, 26, 22, 13],
       [41, 17, 31, 28],
       [16,  4, 19, 85]], dtype=int32)
df_cm_t = pd.DataFrame(
    conf_t,
    index=[f'{c}: {i}' for i, c in enumerate(classes)],
    columns=[f'{i}' for i, c in enumerate(classes)],
)
ax = sns.heatmap(df_cm_t, cmap="Purples");
ax.set_xlabel('Predictions');
ax.set_ylabel('Labels');
conf_sum_t = np.sum(conf_t, axis=1)
for i in range(len(classes)):
acc = 100 * conf_t[i, i] / conf_sum_t[i]
print(f'Accuracy for class {classes[i]}: {acc:.1f}%')
Accuracy for class psychrophile: 71.7%
Accuracy for class mesophile: 21.3%
Accuracy for class thermophile: 26.5%
Accuracy for class hyperthermophile: 68.5%
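Overall training accuracy can be read off the trace of the confusion matrix (a quick check added here, not in the original run); at roughly 47.6% it is only modestly above the 41% test accuracy, which points to underfitting rather than overfitting.

print(f'Train accuracy: {np.trace(conf_t) / conf_t.sum():.4f}')  # ~0.4755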
prior, encoder, decoder, vae = variational_autoencoder(n_inputs=100, encoding_size=2, n_hidden=300)
compile_vae(vae, learning_rate=1e-4)
vae.summary()
Model: "model_3" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= input_1 (InputLayer) [(None, 100)] 0 _________________________________________________________________ dense_2 (Dense) (None, 300) 30300 _________________________________________________________________ dropout_1 (Dropout) (None, 300) 0 _________________________________________________________________ dense_3 (Dense) (None, 5) 1505 _________________________________________________________________ multivariate_normal_tri_l (M ((None, 2), (None, 2)) 0 _________________________________________________________________ model_2 (Model) (None, 100) 61100 ================================================================= Total params: 92,905 Trainable params: 92,905 Non-trainable params: 0 _________________________________________________________________
layer_name = 'dense'
intermediate_layer_model = tf.keras.Model(
    inputs=model.input,
    outputs=model.get_layer(layer_name).output,
)
train_logits = intermediate_layer_model(x_train).numpy()
test_logits = intermediate_layer_model(x_test).numpy()
print(train_logits.shape, test_logits.shape)
(490, 100) (122, 100)
do_train_vae = False
vae_model_path = os.path.join(os.getcwd(), 'saved_models', f'seed_{seed}', f'{rna}_vae.h5')
if do_train_vae:
    vae.fit(
        train_logits, train_logits,
        validation_data=(test_logits, test_logits),
        batch_size=32,
        epochs=200,
    )
    vae.save(vae_model_path)
else:
    vae.load_weights(vae_model_path)
Train on 490 samples, validate on 122 samples
Epoch 1/200
490/490 [==============================] - 2s 3ms/sample - loss: 58.2549 - val_loss: 56.7157
Epoch 2/200
490/490 [==============================] - 0s 123us/sample - loss: 56.8497 - val_loss: 55.1827
Epoch 3/200
490/490 [==============================] - 0s 112us/sample - loss: 55.2468 - val_loss: 53.7047
...
Epoch 199/200
490/490 [==============================] - 0s 106us/sample - loss: -192.7698 - val_loss: -217.0143
Epoch 200/200
490/490 [==============================] - 0s 103us/sample - loss: -191.9163 - val_loss: -212.9696
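Note that the loss turning negative is expected here: the reconstruction term of this VAE's loss is a log-density of a continuous distribution, and continuous log-densities can exceed zero. A standalone illustration (assuming tensorflow_probability is available, as the multivariate_normal_tri_l layer above suggests):

import tensorflow_probability as tfp
dist = tfp.distributions.Normal(loc=0.0, scale=0.1)
print(dist.log_prob(0.0).numpy())  # ~1.38: a positive log-density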
encoding_test_dist = encoder(test_logits)
encoding_test = encoding_test_dist.mean().numpy()
encoding_test.shape
(122, 2)
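The same projection can be computed for the training set; the encoder returns a distribution per sample, so we take its mean as above (an extra view, not in the original):

encoding_train = encoder(train_logits).mean().numpy()
encoding_train.shape  # expected: (490, 2)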
temp_test = balanced_metadata.iloc[test_idx]['temp'].values
temp_test
array([15, 10, 70, 28, 70, 55, 55, 75, 75, 70, 26, 15, 10, 75, 55, 27, 30, 37, 80, 70, 65, 10, 80, 12, 55, 30, 37, 26, 37, 25, 55, 15, 26, 75, 75, 37, 55, 75, 37, 85, 80, 30, 25, 80, 10, 61, 30, 50, 28, 75, 41, 65, 70, 75, 4, 55, 55, 12, 70, 30, 37, 4, 88, 60, 75, 10, 4, 65, 10, 60, 22, 75, 95, 10, 4, 10, 90, 22, 4, 55, 10, 37, 37, 55, 95, 4, 75, 75, 55, 65, 45, 75, 95, 30, 65, 60, 26, 70, 4, 26, 75, 15, 56, 70, 55, 55, 80, 61, 37, 45, 65, 75, 28, 10, 80, 80, 55, 34, 75, 65, 28, 12])
def plot_encoding(encoding, labels, figsize=(12, 6), show_only=None):
    """Scatter the 2-D codes, one colour per class label."""
    f, ax = plt.subplots(1, 1, figsize=figsize)
    palette = sns.color_palette()
    for i, code in enumerate(encoding):
        # Optionally restrict the plot to a subset of class labels.
        if show_only is not None and labels[i] not in show_only:
            continue
        ax.plot(code[0], code[1], 'o', color=palette[labels[i]])
    return f, ax
plot_encoding(encoding_test, labels, show_only=(0, 3));
def plot_encoding_with_colors(encoding, labels, color_vals, figsize=(12, 6), show_only=None):
    """Scatter the 2-D codes, coloured by a continuous value (e.g. growth temperature)."""
    f, ax = plt.subplots(1, 1, figsize=figsize)
    x_d, y_d, colors = [], [], []
    for i in range(len(encoding)):
        # Optionally restrict the plot to a subset of class labels.
        if show_only is not None and labels[i] not in show_only:
            continue
        x_d.append(encoding[i, 0])
        y_d.append(encoding[i, 1])
        colors.append(color_vals[i])
    sc = ax.scatter(x_d, y_d, c=colors, cmap='YlOrRd')
    plt.colorbar(sc)
    return f, ax
plot_encoding_with_colors(encoding_test, labels, temp_test, show_only=(0, 3));
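Finally, an exploratory probe of the latent space (an assumption, not shown in the original): if the decoder mirrors the encoder and maps 2-D codes back to 100-dimensional logits, arbitrary points in the latent plane can be decoded and inspected.

# Hypothetical probe; assumes the decoder accepts (n, 2) codes and returns
# either a tensor or a TFP distribution (hence the hasattr guard).
z_grid = np.array([[0.0, 0.0], [2.0, 0.0], [0.0, 2.0]], dtype=np.float32)
decoded = decoder(z_grid)
recon = decoded.mean().numpy() if hasattr(decoded, 'mean') else np.asarray(decoded)
print(recon.shape)  # expected: (3, 100)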