This notebook replicates what was already done in the previous example, but using functions from the `classification`
module. There is no new ML example here: the purpose of this notebook is to keep the whole workflow short and concise, so that you can use it as a testbed for developing different networks more easily.
from __future__ import print_function
from IPython.display import display
import torch, time
import numpy as np
%matplotlib inline
Let us define the same network as in the previous example.
from hkml_resnet import ResNet

# Fix torch's global RNG (weight initialisation, DataLoader shuffling) so reruns are comparable.
# cuDNN kernels on the GPU may still be nondeterministic.
SEED = 0
torch.manual_seed(SEED)


class BLOB:
    """Plain attribute container bundling everything train_loop / inference operate on."""


blob = BLOB()
# NOTE(review): the positional arguments of ResNet(3, 2, 16, [2,2,2,2,2]) are defined in hkml_resnet;
# presumably one of them is the 3-class output (e-, mu-, gamma) — confirm what 2, 16 and the block list mean.
blob.net = ResNet(3, 2, 16, [2, 2, 2, 2, 2]).cuda()  # residual network (not LeNet), moved to the GPU
blob.criterion = torch.nn.CrossEntropyLoss()  # log-softmax + negative log-likelihood on raw logits
blob.optimizer = torch.optim.Adam(blob.net.parameters(), weight_decay=0.001)  # Adam with L2 weight decay
blob.softmax = torch.nn.Softmax(dim=1)  # not used for training; turns logits into per-class scores
blob.data = None   # placeholder for input data used in training/analysis
blob.label = None  # placeholder for labels used in training/analysis
# Create data loaders
from iotools import loader_factory

# One directory per particle type (hard-coded cluster paths — adjust for your machine).
DATA_DIRS = ['/scratch/hkml_data/IWCDgrid/varyE/e-',
             '/scratch/hkml_data/IWCDgrid/varyE/mu-',
             '/scratch/hkml_data/IWCDgrid/varyE/gamma']

# NOTE(review): assumes H5Dataset reads the slice [start_fraction, start_fraction + use_fraction)
# of each file — confirm in iotools. Under that reading the previous validation slice (0.1-0.2) lay
# inside the training slice (0.0-0.2), so "validation" accuracy was measured on training events.
# The slices below are disjoint.
TRAIN_START, TRAIN_FRACTION = 0.0, 0.2
TEST_START, TEST_FRACTION = 0.2, 0.1
assert (TRAIN_START + TRAIN_FRACTION <= TEST_START
        or TEST_START + TEST_FRACTION <= TRAIN_START), 'train and validation slices overlap'

# for train
blob.train_loader = loader_factory('H5Dataset', batch_size=64, shuffle=True, num_workers=4,
                                   data_dirs=DATA_DIRS, flavour='100k.h5',
                                   start_fraction=TRAIN_START, use_fraction=TRAIN_FRACTION)
# for validation
blob.test_loader = loader_factory('H5Dataset', batch_size=200, shuffle=True, num_workers=2,
                                  data_dirs=DATA_DIRS, flavour='100k.h5',
                                  start_fraction=TEST_START, use_fraction=TEST_FRACTION)
# Attach CSV-file loggers for the training and validation phases
from utils import CSVData
blob.train_log = CSVData('log_train.csv')
blob.test_log = CSVData('log_test.csv')
# Train the network; the second argument is the number of epochs (10 epochs, passed as a float)
from classification import train_loop
N_EPOCHS = 10.0
train_loop(blob, N_EPOCHS)
Epoch 0 Starting @ 2019-04-16 13:39:26
Epoch 1 Starting @ 2019-04-16 13:40:43
Epoch 2 Starting @ 2019-04-16 13:42:00
Epoch 3 Starting @ 2019-04-16 13:43:18
Epoch 4 Starting @ 2019-04-16 13:44:34
Epoch 5 Starting @ 2019-04-16 13:45:52
Epoch 6 Starting @ 2019-04-16 13:47:09
Epoch 7 Starting @ 2019-04-16 13:48:26
Epoch 8 Starting @ 2019-04-16 13:49:43
Epoch 9 Starting @ 2019-04-16 13:50:58
# Plot the metrics recorded in the training and validation CSV logs
from classification import plot_log
train_csv, test_csv = blob.train_log.name, blob.test_log.name
plot_log(train_csv, test_csv)
# Run the trained network over the validation loader and summarise its accuracy
from classification import inference
accuracy, label, prediction = inference(blob, blob.test_loader)
acc_mean, acc_std = accuracy.mean(), accuracy.std()
print('Accuracy mean', acc_mean, 'std', acc_std)
Accuracy mean 0.80390006 std 0.027903825
Plot the confusion matrix
from utils import plot_confusion_matrix
# NOTE(review): presumably list position i names integer label i, so this order must match the
# dataset's label encoding — confirm in iotools (it differs from the DATA_DIRS order e-, mu-, gamma).
CLASS_NAMES = ['gamma', 'electron', 'muon']
plot_confusion_matrix(label, prediction, CLASS_NAMES)