Joeri Hermans (Technical Student, IT-DB-SAS, CERN)
Department of Knowledge Engineering
Maastricht University, The Netherlands
18 January 2017
In this notebook we will show you how to process the MNIST dataset using Distributed Keras. As in the workflow notebook, we will guide you through the complete machine learning pipeline.
To get started, we first load all the required imports. Please make sure you have installed dist-keras and seaborn. Furthermore, we assume that you have access to an installation which provides Apache Spark.
Before you start this notebook, please make sure you have run the "MNIST preprocessing" notebook first, since we will be evaluating a manually "enlarged" dataset.
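If these Python dependencies are missing from your environment, the cell below shows one way to fetch them. It is a minimal sketch that assumes pip is available on the driver machine and that dist-keras and seaborn are installed from PyPI; adapt it to your cluster's package management.
# Hypothetical install cell; assumes pip and PyPI access on this machine.
!pip install --user dist-keras seaborn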
%matplotlib inline
import numpy as np
from keras.optimizers import *
from keras.models import Sequential
from keras.layers.core import *
from keras.layers.convolutional import *
from pyspark import SparkContext
from pyspark import SparkConf
from matplotlib import pyplot as plt
from pyspark import StorageLevel
from pyspark.ml.feature import StandardScaler
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.feature import MinMaxScaler
from pyspark.ml.feature import StringIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from distkeras.trainers import *
from distkeras.predictors import *
from distkeras.transformers import *
from distkeras.evaluators import *
from distkeras.utils import *
Using TensorFlow backend.
In the following cell, adapt the parameters to fit your personal requirements.
# Modify these variables according to your needs.
application_name = "Distributed Keras MNIST Analysis"
using_spark_2 = False
local = False
path = "mnist.parquet"
if local:
    # Tell the master to use local resources.
    master = "local[*]"
    num_processes = 3
    num_executors = 1
else:
    # Tell the master to use YARN.
    master = "yarn-client"
    num_executors = 30
    num_processes = 1
# This variable is derived from the number of processes and executors, and will be used to set the number of model trainers.
num_workers = num_executors * num_processes

print("Number of desired executors: " + str(num_executors))
print("Number of desired processes / executor: " + str(num_processes))
print("Total number of workers: " + str(num_workers))
Number of desired executors: 30
Number of desired processes / executor: 1
Total number of workers: 30
conf = SparkConf()
conf.set("spark.app.name", application_name)
conf.set("spark.master", master)
conf.set("spark.executor.cores", str(num_processes))
conf.set("spark.executor.instances", str(num_executors))
conf.set("spark.locality.wait", "0")
conf.set("spark.executor.memory", "5g")
conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
# Check if the user is running Spark 2.0+.
if using_spark_2:
    from pyspark.sql import SparkSession
    sc = SparkSession.builder.config(conf=conf) \
                     .appName(application_name) \
                     .getOrCreate()
else:
    # Create the Spark context.
    sc = SparkContext(conf=conf)
    # Add the SQL context to the variable scope.
    from pyspark.sql import SQLContext
    sqlContext = SQLContext(sc)
# Check if we are using Spark 2.0.
if using_spark_2:
    reader = sc
else:
    reader = sqlContext
# Read the training and test set.
training_set = reader.read.parquet('data/mnist_train_big.parquet') \
                     .select("features_normalized_dense", "label_encoded", "label")
test_set = reader.read.parquet('data/mnist_test_preprocessed.parquet') \
                 .select("features_normalized_dense", "label_encoded", "label")
# Print the schema of the dataset.
training_set.printSchema()
root
 |-- features_normalized_dense: vector (nullable = true)
 |-- label_encoded: vector (nullable = true)
 |-- label: long (nullable = true)
# Multilayer perceptron: 784 input pixels -> 1000 -> 200 -> 10 softmax outputs,
# with ReLU activations and dropout for regularization.
mlp = Sequential()
mlp.add(Dense(1000, input_shape=(784,)))
mlp.add(Activation('relu'))
mlp.add(Dropout(0.2))
mlp.add(Dense(200))
mlp.add(Activation('relu'))
mlp.add(Dropout(0.2))
mlp.add(Dense(10))
mlp.add(Activation('softmax'))

mlp.summary()
____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to
====================================================================================================
dense_1 (Dense)                  (None, 1000)          785000      dense_input_1[0][0]
____________________________________________________________________________________________________
activation_1 (Activation)        (None, 1000)          0           dense_1[0][0]
____________________________________________________________________________________________________
dropout_1 (Dropout)              (None, 1000)          0           activation_1[0][0]
____________________________________________________________________________________________________
dense_2 (Dense)                  (None, 200)           200200      dropout_1[0][0]
____________________________________________________________________________________________________
activation_2 (Activation)        (None, 200)           0           dense_2[0][0]
____________________________________________________________________________________________________
dropout_2 (Dropout)              (None, 200)           0           activation_2[0][0]
____________________________________________________________________________________________________
dense_3 (Dense)                  (None, 10)            2010        dropout_2[0][0]
____________________________________________________________________________________________________
activation_3 (Activation)        (None, 10)            0           dense_3[0][0]
====================================================================================================
Total params: 987210
____________________________________________________________________________________________________
optimizer_mlp = 'adam'
loss_mlp = 'categorical_crossentropy'
Next, we prepare the training and test set for training and evaluation: we repartition both sets over the number of workers and cache them. The count() calls materialize the cache and report the dataset sizes.
training_set = training_set.repartition(num_workers)
test_set = test_set.repartition(num_workers)
training_set.cache()
test_set.cache()
print("Number of training instances: " + str(training_set.count()))
print("Number of testing instances: " + str(test_set.count()))
Number of training instances: 6060000
Number of testing instances: 10000
We define a utility function which will compute the accuracy for us.
def evaluate_accuracy(model, test_set, features="features_normalized_dense"):
    evaluator = AccuracyEvaluator(prediction_col="prediction_index", label_col="label")
    predictor = ModelPredictor(keras_model=model, features_col=features)
    transformer = LabelIndexTransformer(output_dim=10)
    # Select the columns we need, append the raw predictions,
    # convert them to class indices, and compute the accuracy.
    test_set = test_set.select(features, "label")
    test_set = predictor.predict(test_set)
    test_set = transformer.transform(test_set)
    score = evaluator.evaluate(test_set)

    return score
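We now train the model with dist-keras's ADAG trainer, an asynchronous data-parallel optimization scheme (ADAG stands for Asynchronous Distributed Adaptive Gradients). Roughly speaking, every worker trains a local replica of the model on its partition of the data, and the communication_window parameter controls after how many mini-batches a worker exchanges its accumulated update with the parameter server.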
trainer = ADAG(keras_model=mlp, worker_optimizer=optimizer_mlp, loss=loss_mlp, num_workers=num_workers,
               batch_size=4, communication_window=5, num_epoch=1,
               features_col="features_normalized_dense", label_col="label_encoded")
# Train the model on the training set.
trained_model = trainer.train(training_set)
# View the weights of the trained model.
trained_model.get_weights()
[array([[-0.02490237, -0.01861665,  0.03102627, ...,  0.01722135,
          0.02223415, -0.04933412],
        ...,
        [ 0.04610632,  0.01409597,  0.03790993, ..., -0.02038677,
         -0.03649681,  0.04742099]], dtype=float32),
 ...,
 array([-0.28950188, -0.33981469, -0.49054769, -0.24692491, -0.54108179,
        -0.53850734, -0.51629019, -0.45034203,  0.94987106,  0.34385717], dtype=float32)]
print("Training time: " + str(trainer.get_training_time()))
print("Accuracy: " + str(evaluate_accuracy(trained_model, test_set)))
Training time: 22619.2383449
Accuracy: 0.9859
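Finally, when you are done experimenting, you may want to release the cluster resources. A minimal sketch, assuming no later cells still need the cached data or the Spark application (note that both SparkContext and SparkSession provide a stop() method):
# Release the cached DataFrames and shut down the Spark application.
training_set.unpersist()
test_set.unpersist()
sc.stop()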