Large arrays of radio antennas can be used to measure cosmic rays by recording the electromagnetic radiation the induced air showers generate in the atmosphere. These radio signals are strongly contaminated by galactic noise as well as signals of human origin. Since cosmic-ray pulses look similar to this background, detecting cosmic-ray events can be challenging.
In this exercise, we design an RNN to classify whether a recorded radio trace contains a cosmic-ray event or only noise.
The signal-to-noise ratio (SNR) of a measured trace $S(t)$ is defined as follows:
$$\mathrm{SNR}=\frac{S^{\mathrm{signal}}(t)_\mathrm{max}}{\mathrm{RMS}[S(t)]},$$where $S^{\mathrm{signal}}(t)_\mathrm{max}$ denotes the maximum amplitude of the (true) signal.
Typical cosmic-ray observatories enable a precise reconstruction at an SNR of roughly 3.
We choose a challenging setup in this task and try to identify cosmic-ray events in signal traces with an SNR of 2.
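To make the definition concrete, the SNR can be computed directly from its formula. The sketch below uses a synthetic pulse and synthetic noise (the pulse shape, amplitudes, and seed are illustrative assumptions, not the dataset's traces):

```python
import numpy as np

# Synthetic stand-ins for a true signal and a measured trace (assumptions).
rng = np.random.default_rng(0)
signal = np.zeros(500)
signal[250:260] = 2.0 * np.hanning(10)   # short pulse with peak amplitude ~2
noise = rng.normal(0.0, 1.0, 500)        # unit-RMS background noise
trace = signal + noise                   # "measured" trace S(t)

# SNR = max amplitude of the true signal / RMS of the measured trace
snr = signal.max() / np.sqrt(np.mean(trace**2))
print(f"SNR ~ {snr:.2f}")
```

With a peak amplitude of about 2 and unit-RMS noise, this lands near the SNR-of-2 regime targeted in this task.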
Training RNNs can be computationally demanding; we therefore recommend using a GPU for this task.
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
layers = keras.layers
print("keras", keras.__version__)
keras 2.4.0
In this task, we use a simulation of cosmic-ray-induced air showers that are measured by radio antennas.
For more information, see https://arxiv.org/abs/1901.04079.
The task is to design an RNN that identifies whether a measured signal trace (shortened to 500 time steps) contains a cosmic-ray signal or not.
import os
import gdown
url = "https://drive.google.com/u/0/uc?export=download&confirm=HgGH&id=1R-qfxO1jVh88TC9Gnm9JGMomSRg0Zpkx"
output = 'radio_data.npz'
if not os.path.exists(output):
    gdown.download(url, output, quiet=True)
f = np.load(output)
n_train = 40000
x_train, x_test = f["traces"][:n_train], f["traces"][n_train:] # measured traces (signal + colored noise)
signals = f["signals"] # signal part (only available for cosmic-ray events)
labels = (signals.std(axis=-1) != 0).astype(float) # define training label (1=cosmic event, 0=noise)
y_train, y_test = labels[:n_train], labels[n_train:]
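The labeling rule above marks a trace as a cosmic-ray event exactly when its stored signal part is non-zero. A tiny demonstration of the same rule on two synthetic "signal parts" (the arrays here are toy stand-ins, not the dataset):

```python
import numpy as np

# Two toy signal parts: one non-zero pulse shape, one all-zero (pure noise).
signals = np.stack([np.sin(np.linspace(0, 6, 500)),  # non-zero -> event
                    np.zeros(500)])                  # all-zero -> noise
labels = (signals.std(axis=-1) != 0).astype(float)
print(labels)  # [1. 0.]
```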
Left: signal trace containing a cosmic-ray event; the underlying cosmic-ray signal is shown in red, the background plus signal in blue. Right: pure background noise.
fs = 180e6 # Sampling frequency of antenna setup 180 MHz
t = np.arange(500) / fs * 1e6
idx = np.random.randint(0, int(labels.sum()))              # random cosmic-ray event
idx2 = np.random.randint(0, int(n_train - y_train.sum()))  # random noise trace from the training set
plt.figure(1, (12, 4))
plt.subplot(1, 2, 1)
plt.plot(t, np.real(f["traces"][labels.astype(bool)][idx]), linewidth=1, color="b", label="Measured trace")
plt.plot(t, np.real(signals[labels.astype(bool)][idx]), linewidth=1, color="r", label="CR signal")
plt.ylabel('Amplitude / mV')
plt.xlabel(r'Time / $\mu \mathrm{s}$')
plt.legend()
plt.title("Cosmic-ray event")
plt.subplot(1, 2, 2)
plt.plot(t, np.real(x_train[~y_train.astype(bool)][idx2]), linewidth=1, color="b", label="Measured trace")
plt.ylabel('Amplitude / mV')
plt.xlabel(r'Time / $\mu \mathrm{s}$')
plt.legend()
plt.title("Noise event")
plt.grid(True)
plt.tight_layout()
sigma = x_train.std()
x_train /= sigma
x_test /= sigma
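Note that both sets are divided by the standard deviation estimated on the training data only, so no test-set statistics leak into the preprocessing. A toy check of this pattern (the synthetic traces and their scale are assumptions for illustration):

```python
import numpy as np

# Toy traces drawn with an arbitrary amplitude scale (assumption: std of 3 mV).
rng = np.random.default_rng(1)
x_tr = rng.normal(0.0, 3.0, (100, 500))
x_te = rng.normal(0.0, 3.0, (20, 500))

sigma = x_tr.std()          # estimated on training data only
x_tr, x_te = x_tr / sigma, x_te / sigma

print(round(float(x_tr.std()), 3))  # training std is exactly 1 after scaling
```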
In the following, we design an RNN-based classifier to identify cosmic-ray events in the measured traces.
model = keras.models.Sequential()
model.add(layers.Bidirectional(layers.LSTM(32, return_sequences=True), input_shape=(500,1)))
model.add(layers.Bidirectional(layers.LSTM(64, return_sequences=True)))
model.add(layers.Bidirectional(layers.LSTM(10, return_sequences=True)))
model.add(layers.Flatten())
model.add(layers.Dropout(0.3))
model.add(layers.Dense(1, activation="sigmoid"))
model.summary()
Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
bidirectional (Bidirectional (None, 500, 64)           8704
_________________________________________________________________
bidirectional_1 (Bidirection (None, 500, 128)          66048
_________________________________________________________________
bidirectional_2 (Bidirection (None, 500, 20)           11120
_________________________________________________________________
flatten (Flatten)            (None, 10000)             0
_________________________________________________________________
dropout (Dropout)            (None, 10000)             0
_________________________________________________________________
dense (Dense)                (None, 1)                 10001
=================================================================
Total params: 95,873
Trainable params: 95,873
Non-trainable params: 0
_________________________________________________________________
model.compile(
loss='binary_crossentropy',
optimizer=keras.optimizers.Adam(1e-3, decay=0.00008),
metrics=['accuracy'])
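The binary cross-entropy loss used here penalizes confident wrong predictions heavily; written out for a single label/prediction pair it is $-\left[y\log p + (1-y)\log(1-p)\right]$. A minimal NumPy sketch (the clipping constant is an assumption to avoid $\log 0$):

```python
import numpy as np

def bce(y, p, eps=1e-7):
    """Binary cross-entropy for one target y in {0, 1} and prediction p in (0, 1)."""
    p = np.clip(p, eps, 1 - eps)  # guard against log(0)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

print(round(float(bce(1.0, 0.9)), 4))  # -log(0.9) ~ 0.1054
print(round(float(bce(1.0, 0.1)), 4))  # confident miss -> much larger loss
```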
results = model.fit(x_train[...,np.newaxis], y_train,
batch_size=128,
epochs=100,
verbose=1,
validation_split=0.1,
callbacks = [keras.callbacks.ReduceLROnPlateau(factor=0.5, patience=5, verbose=1, min_lr=1e-5),
keras.callbacks.EarlyStopping(patience=15, verbose=1)]
)
Epoch 1/100
282/282 [==============================] - 42s 119ms/step - loss: 0.6340 - accuracy: 0.6387 - val_loss: 0.5846 - val_accuracy: 0.6823
[...]
Epoch 00012: ReduceLROnPlateau reducing learning rate to 0.0005000000237487257.
[...]
Epoch 00017: ReduceLROnPlateau reducing learning rate to 0.0002500000118743628.
[...]
Epoch 36/100
282/282 [==============================] - 32s 113ms/step - loss: 0.2259 - accuracy: 0.9109 - val_loss: 0.2173 - val_accuracy: 0.9137
[...]
Epoch 00041: ReduceLROnPlateau reducing learning rate to 0.0001250000059371814.
[...]
Epoch 00061: ReduceLROnPlateau reducing learning rate to 6.25000029685907e-05.
[...]
Epoch 00074: ReduceLROnPlateau reducing learning rate to 3.125000148429535e-05.
[...]
Epoch 00080: ReduceLROnPlateau reducing learning rate to 1.5625000742147677e-05.
[...]
Epoch 00088: ReduceLROnPlateau reducing learning rate to 1e-05.
[...]
Epoch 100/100
282/282 [==============================] - 32s 113ms/step - loss: 0.0228 - accuracy: 0.9931 - val_loss: 0.0274 - val_accuracy: 0.9940
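The ReduceLROnPlateau callback visible in the log halves the learning rate whenever the monitored loss stops improving for several epochs. A simplified sketch of that logic (Keras' exact tie-breaking and cooldown behavior differs slightly; this is a toy model, not the library implementation):

```python
def lr_schedule(losses, lr0=1e-3, factor=0.5, patience=5, min_lr=1e-5):
    """Return the learning rate used at each epoch for a given loss history."""
    lr, best, wait, history = lr0, float("inf"), 0, []
    for loss in losses:
        if loss < best:
            best, wait = loss, 0          # improvement: reset patience counter
        else:
            wait += 1
            if wait > patience:           # plateau exceeded: reduce the rate
                lr = max(lr * factor, min_lr)
                wait = 0
        history.append(lr)
    return history

# Loss stalls after the first epoch -> one reduction once patience is exceeded.
print(lr_schedule([1.0] * 9, patience=5))
```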
model.evaluate(x_test[...,np.newaxis], y_test)
313/313 [==============================] - 13s 40ms/step - loss: 0.0250 - accuracy: 0.9920
[0.024966726079583168, 0.9919999837875366]
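Accuracy depends on the implicit 0.5 threshold on the sigmoid output; the ROC AUC summarizes the classifier over all thresholds and is a useful complement here. A NumPy-only sketch using the rank interpretation of the AUC (the toy labels and scores below stand in for `y_test` and `model.predict` output):

```python
import numpy as np

def roc_auc(y_true, scores):
    """AUC = probability that a random positive outranks a random negative
    (Mann-Whitney U statistic; ties count as 0.5)."""
    pos = scores[y_true == 1]
    neg = scores[y_true == 0]
    wins = (pos[:, None] > neg[None, :]).sum() + 0.5 * (pos[:, None] == neg[None, :]).sum()
    return wins / (len(pos) * len(neg))

# Toy stand-ins for true labels and predicted signal probabilities.
y_true = np.array([1, 0, 1, 0, 1])
scores = np.array([0.9, 0.2, 0.7, 0.4, 0.3])
print(roc_auc(y_true, scores))  # 5 of 6 positive/negative pairs ranked correctly
```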
plt.figure(1, (12, 4))
plt.subplot(1, 2, 1)
plt.plot(results.history['loss'])
plt.plot(results.history['val_loss'])
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper right')
plt.subplot(1, 2, 2)
plt.plot(results.history['accuracy'])
plt.plot(results.history['val_accuracy'])
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.tight_layout()