import numpy as np
rng = np.random.default_rng()
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, TABLEAU_COLORS
import seaborn as sns
# load mlp from github
!git clone https://github.com/StephenTGibson/homemade-artificial-neural-networks.git &> /dev/null
import sys
sys.path.append('/content/homemade-artificial-neural-networks')
import multilayerPerceptron as mlp
Create a dataset
# Generate a synthetic classification dataset: `objectsPerClass` points for
# each of `classes` classes, each point having `variables` input values.
# The last column of `data` holds the class label; the preceding columns are
# the network inputs (see `data[:, :-1]` / `data[:, -1]` further down).
objectsPerClass = 10
variables = 2
classes = 4
# NOTE(review): `step` presumably controls the spacing between the class
# clouds — TODO confirm against mlp.createDataset_clouds.
data = mlp.createDataset_clouds(objectsPerClass, variables, classes, step=1.5)

# Scatter the first two input variables, coloured by class label. Use the
# last column (-1) rather than a fixed index so the label is still picked up
# correctly if `variables` is changed.
sns.scatterplot(
    x=data[:, 0],
    y=data[:, 1],
    hue=data[:, -1],
    palette='tab10',
)
plt.show()
Define the network
Each layer takes the following arguments: number of inputs, number of outputs, activation function.
# Width of every hidden layer. Adjacent layers must agree on this size
# (one layer's outputs are the next layer's inputs), so define it once.
hiddenUnits = 6

# NOTE(review): the constructor appears to take (first layer, list of hidden
# layers, last layer, hyperparameters...) — TODO confirm against mlp.
model = mlp.MultilayerPerceptron(
    # `variables` inputs -> hidden width.
    mlp.Layer(variables, hiddenUnits, 'sigmoid'),
    # Intermediate layers: hidden width -> hidden width.
    [
        mlp.Layer(hiddenUnits, hiddenUnits, 'sigmoid'),
    ],
    # One softmax output per class, paired with the multi-class
    # cross-entropy loss below.
    mlp.Layer(hiddenUnits, classes, 'softmax'),
    learningRate=0.001,
    momentum=0.99,
    lossFunction=mlp.multiCrossEntropy,
)
Train the network (online SGD)
# Train the model; `history` holds the loss recorded during training and its
# final entry is reported as the final mean loss.
epochs = 2000
history = model.train(data, epochs)
print('Final mean loss: {:.3g}'.format(history[-1]))

# Evaluate on the training set itself: inputs are every column except the
# last, which holds the true class label.
predictions = model.test(data[:, :-1])
# 1. where the prediction matches the label, 0. otherwise; kept as a float
# array because it is reused as the marker style in the plot below.
correct = (data[:, -1] == predictions) * 1.
print('Final training accuracy: {:.1f}%'.format(100 * correct.sum() / len(data)))
Epoch: 0 Classification accuracy: 32.5%
Epoch: 400 Classification accuracy: 82.5%
Epoch: 800 Classification accuracy: 85.0%
Epoch: 1200 Classification accuracy: 85.0%
Epoch: 1600 Classification accuracy: 87.5%
Final mean loss: 0.271
Final training accuracy: 87.5%
Plot the training loss and the learnt decision boundary
# Map class index -> Tableau colour name ('tab:blue', 'tab:orange', ...).
# TABLEAU_COLORS is ordered like the 'tab10' palette used for the first plot.
cDict = dict(enumerate(TABLEAU_COLORS))
# Discrete colormap with exactly one colour per class for the label mesh.
cMap = ListedColormap([cDict[num] for num in range(classes)])

fig, axs = plt.subplots(1, 2, figsize=(12, 6), constrained_layout=True)

# Left panel: loss recorded during training.
axs[0].plot(history)
axs[0].set_xlabel('Epoch')
axs[0].set_ylabel('Mean loss')
axs[0].set_title('Training loss')

# Right panel: predicted class over a grid, overlaid with the confidence mesh.
xMesh, yMesh, labelsMesh, confidenceMesh = mlp.createDecisionBoundaryMesh('multiClass', model, data)
# Pin the colour normalisation to the full label range so that class k always
# gets colour k. Without vmin/vmax, pcolormesh autoscales to the labels that
# happen to appear in the mesh, and region colours stop matching the point
# colours whenever some class is never predicted.
axs[1].pcolormesh(
    xMesh, yMesh, labelsMesh,
    cmap=cMap, vmin=-0.5, vmax=classes - 0.5, shading='auto',
)
axs[1].pcolormesh(xMesh, yMesh, confidenceMesh, cmap='Greys', alpha=0.5, shading='auto')

# Overlay the training points: colour = true class; marker 'o' when the
# prediction was correct, 'X' when it was wrong.
sns.scatterplot(
    ax=axs[1],
    x=data[:, 0],
    y=data[:, 1],
    hue=data[:, -1],
    style=correct,
    palette=cDict,
    markers={0: 'X', 1: 'o'},
    edgecolor='black',
    linewidth=1.5,
    legend=False,
)
axs[1].set_title('Decision boundary')
plt.show()