%use deeplearning4j
%use krangl
%use lets-plot
val iris_data = "sepal-length,sepal-width,petal-length,petal-width,species\n5.1,3.5,1.4,0.2,Iris-setosa\n4.9,3.0,1.4,0.2,Iris-setosa\n4.7,3.2,1.3,0.2,Iris-setosa\n4.6,3.1,1.5,0.2,Iris-setosa\n5.0,3.6,1.4,0.2,Iris-setosa\n5.4,3.9,1.7,0.4,Iris-setosa\n4.6,3.4,1.4,0.3,Iris-setosa\n5.0,3.4,1.5,0.2,Iris-setosa\n4.4,2.9,1.4,0.2,Iris-setosa\n4.9,3.1,1.5,0.1,Iris-setosa\n5.4,3.7,1.5,0.2,Iris-setosa\n4.8,3.4,1.6,0.2,Iris-setosa\n4.8,3.0,1.4,0.1,Iris-setosa\n4.3,3.0,1.1,0.1,Iris-setosa\n5.8,4.0,1.2,0.2,Iris-setosa\n5.7,4.4,1.5,0.4,Iris-setosa\n5.4,3.9,1.3,0.4,Iris-setosa\n5.1,3.5,1.4,0.3,Iris-setosa\n5.7,3.8,1.7,0.3,Iris-setosa\n5.1,3.8,1.5,0.3,Iris-setosa\n5.4,3.4,1.7,0.2,Iris-setosa\n5.1,3.7,1.5,0.4,Iris-setosa\n4.6,3.6,1.0,0.2,Iris-setosa\n5.1,3.3,1.7,0.5,Iris-setosa\n4.8,3.4,1.9,0.2,Iris-setosa\n5.0,3.0,1.6,0.2,Iris-setosa\n5.0,3.4,1.6,0.4,Iris-setosa\n5.2,3.5,1.5,0.2,Iris-setosa\n5.2,3.4,1.4,0.2,Iris-setosa\n4.7,3.2,1.6,0.2,Iris-setosa\n4.8,3.1,1.6,0.2,Iris-setosa\n5.4,3.4,1.5,0.4,Iris-setosa\n5.2,4.1,1.5,0.1,Iris-setosa\n5.5,4.2,1.4,0.2,Iris-setosa\n4.9,3.1,1.5,0.1,Iris-setosa\n5.0,3.2,1.2,0.2,Iris-setosa\n5.5,3.5,1.3,0.2,Iris-setosa\n4.9,3.1,1.5,0.1,Iris-setosa\n4.4,3.0,1.3,0.2,Iris-setosa\n5.1,3.4,1.5,0.2,Iris-setosa\n5.0,3.5,1.3,0.3,Iris-setosa\n4.5,2.3,1.3,0.3,Iris-setosa\n4.4,3.2,1.3,0.2,Iris-setosa\n5.0,3.5,1.6,0.6,Iris-setosa\n5.1,3.8,1.9,0.4,Iris-setosa\n4.8,3.0,1.4,0.3,Iris-setosa\n5.1,3.8,1.6,0.2,Iris-setosa\n4.6,3.2,1.4,0.2,Iris-setosa\n5.3,3.7,1.5,0.2,Iris-setosa\n5.0,3.3,1.4,0.2,Iris-setosa\n7.0,3.2,4.7,1.4,Iris-versicolor\n6.4,3.2,4.5,1.5,Iris-versicolor\n6.9,3.1,4.9,1.5,Iris-versicolor\n5.5,2.3,4.0,1.3,Iris-versicolor\n6.5,2.8,4.6,1.5,Iris-versicolor\n5.7,2.8,4.5,1.3,Iris-versicolor\n6.3,3.3,4.7,1.6,Iris-versicolor\n4.9,2.4,3.3,1.0,Iris-versicolor\n6.6,2.9,4.6,1.3,Iris-versicolor\n5.2,2.7,3.9,1.4,Iris-versicolor\n5.0,2.0,3.5,1.0,Iris-versicolor\n5.9,3.0,4.2,1.5,Iris-versicolor\n6.0,2.2,4.0,1.0,Iris-versicolor\n6.1,2.9,4.7,1.4,Iris-versicolor\n5.6,2.9,3.6,
1.3,Iris-versicolor\n6.7,3.1,4.4,1.4,Iris-versicolor\n5.6,3.0,4.5,1.5,Iris-versicolor\n5.8,2.7,4.1,1.0,Iris-versicolor\n6.2,2.2,4.5,1.5,Iris-versicolor\n5.6,2.5,3.9,1.1,Iris-versicolor\n5.9,3.2,4.8,1.8,Iris-versicolor\n6.1,2.8,4.0,1.3,Iris-versicolor\n6.3,2.5,4.9,1.5,Iris-versicolor\n6.1,2.8,4.7,1.2,Iris-versicolor\n6.4,2.9,4.3,1.3,Iris-versicolor\n6.6,3.0,4.4,1.4,Iris-versicolor\n6.8,2.8,4.8,1.4,Iris-versicolor\n6.7,3.0,5.0,1.7,Iris-versicolor\n6.0,2.9,4.5,1.5,Iris-versicolor\n5.7,2.6,3.5,1.0,Iris-versicolor\n5.5,2.4,3.8,1.1,Iris-versicolor\n5.5,2.4,3.7,1.0,Iris-versicolor\n5.8,2.7,3.9,1.2,Iris-versicolor\n6.0,2.7,5.1,1.6,Iris-versicolor\n5.4,3.0,4.5,1.5,Iris-versicolor\n6.0,3.4,4.5,1.6,Iris-versicolor\n6.7,3.1,4.7,1.5,Iris-versicolor\n6.3,2.3,4.4,1.3,Iris-versicolor\n5.6,3.0,4.1,1.3,Iris-versicolor\n5.5,2.5,4.0,1.3,Iris-versicolor\n5.5,2.6,4.4,1.2,Iris-versicolor\n6.1,3.0,4.6,1.4,Iris-versicolor\n5.8,2.6,4.0,1.2,Iris-versicolor\n5.0,2.3,3.3,1.0,Iris-versicolor\n5.6,2.7,4.2,1.3,Iris-versicolor\n5.7,3.0,4.2,1.2,Iris-versicolor\n5.7,2.9,4.2,1.3,Iris-versicolor\n6.2,2.9,4.3,1.3,Iris-versicolor\n5.1,2.5,3.0,1.1,Iris-versicolor\n5.7,2.8,4.1,1.3,Iris-versicolor\n6.3,3.3,6.0,2.5,Iris-virginica\n5.8,2.7,5.1,1.9,Iris-virginica\n7.1,3.0,5.9,2.1,Iris-virginica\n6.3,2.9,5.6,1.8,Iris-virginica\n6.5,3.0,5.8,2.2,Iris-virginica\n7.6,3.0,6.6,2.1,Iris-virginica\n4.9,2.5,4.5,1.7,Iris-virginica\n7.3,2.9,6.3,1.8,Iris-virginica\n6.7,2.5,5.8,1.8,Iris-virginica\n7.2,3.6,6.1,2.5,Iris-virginica\n6.5,3.2,5.1,2.0,Iris-virginica\n6.4,2.7,5.3,1.9,Iris-virginica\n6.8,3.0,5.5,2.1,Iris-virginica\n5.7,2.5,5.0,2.0,Iris-virginica\n5.8,2.8,5.1,2.4,Iris-virginica\n6.4,3.2,5.3,2.3,Iris-virginica\n6.5,3.0,5.5,1.8,Iris-virginica\n7.7,3.8,6.7,2.2,Iris-virginica\n7.7,2.6,6.9,2.3,Iris-virginica\n6.0,2.2,5.0,1.5,Iris-virginica\n6.9,3.2,5.7,2.3,Iris-virginica\n5.6,2.8,4.9,2.0,Iris-virginica\n7.7,2.8,6.7,2.0,Iris-virginica\n6.3,2.7,4.9,1.8,Iris-virginica\n6.7,3.3,5.7,2.1,Iris-virginica\n7.2,3.2,6.0,1.8,Iris-vir
ginica\n6.2,2.8,4.8,1.8,Iris-virginica\n6.1,3.0,4.9,1.8,Iris-virginica\n6.4,2.8,5.6,2.1,Iris-virginica\n7.2,3.0,5.8,1.6,Iris-virginica\n7.4,2.8,6.1,1.9,Iris-virginica\n7.9,3.8,6.4,2.0,Iris-virginica\n6.4,2.8,5.6,2.2,Iris-virginica\n6.3,2.8,5.1,1.5,Iris-virginica\n6.1,2.6,5.6,1.4,Iris-virginica\n7.7,3.0,6.1,2.3,Iris-virginica\n6.3,3.4,5.6,2.4,Iris-virginica\n6.4,3.1,5.5,1.8,Iris-virginica\n6.0,3.0,4.8,1.8,Iris-virginica\n6.9,3.1,5.4,2.1,Iris-virginica\n6.7,3.1,5.6,2.4,Iris-virginica\n6.9,3.1,5.1,2.3,Iris-virginica\n5.8,2.7,5.1,1.9,Iris-virginica\n6.8,3.2,5.9,2.3,Iris-virginica\n6.7,3.3,5.7,2.5,Iris-virginica\n6.7,3.0,5.2,2.3,Iris-virginica\n6.3,2.5,5.0,1.9,Iris-virginica\n6.5,3.0,5.2,2.0,Iris-virginica\n6.2,3.4,5.4,2.3,Iris-virginica\n5.9,3.0,5.1,1.8,Iris-virginica"
import java.util.*
import java.io.StringReader
// Parse the embedded CSV text into a krangl DataFrame (one row per flower).
// No shuffle here on purpose: rows are shuffled later with a fixed seed
// (fullDataSet.shuffle(seed)). An unseeded shuffle at this point would change
// the input order on every run and thus make the train/test split, and the
// final evaluation, non-reproducible despite that seed.
val iris = DataFrame.readDelim(StringReader(iris_data))
iris.head()
sepal-length | sepal-width | petal-length | petal-width | species |
---|---|---|---|---|
5.1 | 3.3 | 1.7 | 0.5 | Iris-setosa |
5.8 | 2.7 | 5.1 | 1.9 | Iris-virginica |
5.6 | 2.8 | 4.9 | 2.0 | Iris-virginica |
4.8 | 3.0 | 1.4 | 0.3 | Iris-setosa |
7.7 | 2.6 | 6.9 | 2.3 | Iris-virginica |
// Scatter plot of sepal length (x) against sepal width (y), one colour per species.
val plotData = mapOf(
    "x" to iris["sepal-length"].values().toList(),
    "y" to iris["sepal-width"].values().toList(),
    "color" to iris["species"].values().toList()
)
val points = geomPoint(data = plotData, alpha = 1.0) {
    x = "x"
    y = "y"
    color = "color"
}
ggplot() + points
// Drop the "species" label column, leaving only the four numeric feature columns
// (sepal-length, sepal-width, petal-length, petal-width) used as network inputs.
val irisWithoutLabel = iris.remove("species")
irisWithoutLabel.head()
sepal-length | sepal-width | petal-length | petal-width |
---|---|---|---|
5.1 | 3.3 | 1.7 | 0.5 |
5.8 | 2.7 | 5.1 | 1.9 |
5.6 | 2.8 | 4.9 | 2.0 |
4.8 | 3.0 | 1.4 | 0.3 |
7.7 | 2.6 | 6.9 | 2.3 |
// Convert the feature columns into a row-major (rows x features) matrix of doubles
// (150 x 4 for the bundled iris data).
// The dimensions come from the data frame instead of being hard-coded, so this cell
// keeps working if rows or feature columns are added or removed.
val row = irisWithoutLabel.nrow
val col = irisWithoutLabel.ncol
val irisMatrix = Array(row) { r ->
    DoubleArray(col) { c ->
        // krangl indexes column first (irisWithoutLabel[c]), then row ([r]).
        // Going through Number also accepts columns krangl inferred as integers.
        (irisWithoutLabel[c][r] as Number).toDouble()
    }
}
println(Arrays.deepToString(irisMatrix).replace("], ", "]\n"))
[[5.1, 3.3, 1.7, 0.5] [5.8, 2.7, 5.1, 1.9] [5.6, 2.8, 4.9, 2.0] [4.8, 3.0, 1.4, 0.3] [7.7, 2.6, 6.9, 2.3] [5.6, 2.9, 3.6, 1.3] [6.9, 3.1, 5.4, 2.1] [5.9, 3.0, 4.2, 1.5] [4.9, 3.1, 1.5, 0.1] [6.8, 2.8, 4.8, 1.4] [6.0, 2.2, 5.0, 1.5] [6.0, 3.4, 4.5, 1.6] [5.4, 3.9, 1.3, 0.4] [5.7, 3.0, 4.2, 1.2] [7.2, 3.0, 5.8, 1.6] [6.0, 2.7, 5.1, 1.6] [6.4, 3.2, 5.3, 2.3] [5.7, 2.8, 4.1, 1.3] [5.7, 2.5, 5.0, 2.0] [6.2, 2.8, 4.8, 1.8] [5.0, 3.5, 1.3, 0.3] [5.7, 4.4, 1.5, 0.4] [6.3, 2.5, 5.0, 1.9] [7.7, 3.0, 6.1, 2.3] [4.8, 3.0, 1.4, 0.1] [5.8, 2.7, 3.9, 1.2] [5.1, 2.5, 3.0, 1.1] [6.4, 2.8, 5.6, 2.1] [5.3, 3.7, 1.5, 0.2] [4.6, 3.4, 1.4, 0.3] [7.6, 3.0, 6.6, 2.1] [4.5, 2.3, 1.3, 0.3] [5.6, 2.7, 4.2, 1.3] [5.7, 2.6, 3.5, 1.0] [6.7, 3.0, 5.0, 1.7] [6.5, 3.0, 5.8, 2.2] [5.0, 2.3, 3.3, 1.0] [6.1, 3.0, 4.9, 1.8] [6.5, 3.0, 5.2, 2.0] [6.2, 3.4, 5.4, 2.3] [4.4, 2.9, 1.4, 0.2] [5.2, 3.5, 1.5, 0.2] [7.2, 3.6, 6.1, 2.5] [5.5, 4.2, 1.4, 0.2] [6.4, 2.9, 4.3, 1.3] [4.9, 3.0, 1.4, 0.2] [6.3, 2.5, 4.9, 1.5] [5.5, 2.4, 3.7, 1.0] [4.7, 3.2, 1.6, 0.2] [6.3, 2.7, 4.9, 1.8] [6.3, 2.3, 4.4, 1.3] [7.1, 3.0, 5.9, 2.1] [5.0, 3.5, 1.6, 0.6] [6.8, 3.0, 5.5, 2.1] [4.8, 3.4, 1.9, 0.2] [6.7, 3.1, 5.6, 2.4] [5.8, 2.6, 4.0, 1.2] [5.0, 3.2, 1.2, 0.2] [6.7, 3.3, 5.7, 2.5] [5.1, 3.5, 1.4, 0.2] [6.4, 2.7, 5.3, 1.9] [7.0, 3.2, 4.7, 1.4] [6.1, 2.8, 4.7, 1.2] [5.4, 3.4, 1.7, 0.2] [4.9, 2.4, 3.3, 1.0] [5.2, 3.4, 1.4, 0.2] [6.5, 2.8, 4.6, 1.5] [5.4, 3.0, 4.5, 1.5] [7.3, 2.9, 6.3, 1.8] [5.2, 2.7, 3.9, 1.4] [5.4, 3.9, 1.7, 0.4] [6.2, 2.2, 4.5, 1.5] [5.1, 3.5, 1.4, 0.3] [4.8, 3.4, 1.6, 0.2] [7.7, 3.8, 6.7, 2.2] [5.6, 3.0, 4.5, 1.5] [6.3, 3.4, 5.6, 2.4] [5.8, 2.8, 5.1, 2.4] [5.5, 2.3, 4.0, 1.3] [4.9, 2.5, 4.5, 1.7] [6.0, 2.2, 4.0, 1.0] [5.0, 2.0, 3.5, 1.0] [5.9, 3.2, 4.8, 1.8] [5.4, 3.4, 1.5, 0.4] [6.9, 3.1, 4.9, 1.5] [4.9, 3.1, 1.5, 0.1] [5.2, 4.1, 1.5, 0.1] [5.1, 3.8, 1.5, 0.3] [5.1, 3.8, 1.6, 0.2] [6.7, 3.1, 4.7, 1.5] [5.9, 3.0, 5.1, 1.8] [5.8, 4.0, 1.2, 0.2] [4.3, 3.0, 1.1, 0.1] [6.7, 2.5, 5.8, 1.8] [6.3, 3.3, 6.0, 2.5] 
[5.6, 2.5, 3.9, 1.1] [4.4, 3.2, 1.3, 0.2] [4.6, 3.1, 1.5, 0.2] [5.5, 2.6, 4.4, 1.2] [6.9, 3.1, 5.1, 2.3] [6.0, 2.9, 4.5, 1.5] [7.2, 3.2, 6.0, 1.8] [6.1, 2.8, 4.0, 1.3] [5.7, 2.9, 4.2, 1.3] [5.8, 2.7, 4.1, 1.0] [4.8, 3.1, 1.6, 0.2] [6.9, 3.2, 5.7, 2.3] [5.5, 2.4, 3.8, 1.1] [5.0, 3.4, 1.5, 0.2] [4.6, 3.2, 1.4, 0.2] [4.9, 3.1, 1.5, 0.1] [6.0, 3.0, 4.8, 1.8] [6.3, 2.9, 5.6, 1.8] [6.6, 3.0, 4.4, 1.4] [7.9, 3.8, 6.4, 2.0] [5.6, 3.0, 4.1, 1.3] [5.7, 3.8, 1.7, 0.3] [5.0, 3.4, 1.6, 0.4] [5.7, 2.8, 4.5, 1.3] [6.7, 3.3, 5.7, 2.1] [6.7, 3.1, 4.4, 1.4] [6.7, 3.0, 5.2, 2.3] [5.5, 2.5, 4.0, 1.3] [5.0, 3.3, 1.4, 0.2] [4.4, 3.0, 1.3, 0.2] [6.6, 2.9, 4.6, 1.3] [7.4, 2.8, 6.1, 1.9] [6.5, 3.0, 5.5, 1.8] [6.3, 2.8, 5.1, 1.5] [6.4, 3.2, 4.5, 1.5] [6.1, 2.9, 4.7, 1.4] [4.6, 3.6, 1.0, 0.2] [5.4, 3.7, 1.5, 0.2] [5.5, 3.5, 1.3, 0.2] [6.1, 3.0, 4.6, 1.4] [5.8, 2.7, 5.1, 1.9] [6.8, 3.2, 5.9, 2.3] [6.4, 3.1, 5.5, 1.8] [7.7, 2.8, 6.7, 2.0] [5.0, 3.0, 1.6, 0.2] [6.2, 2.9, 4.3, 1.3] [5.1, 3.4, 1.5, 0.2] [6.5, 3.2, 5.1, 2.0] [5.1, 3.7, 1.5, 0.4] [6.4, 2.8, 5.6, 2.2] [5.1, 3.8, 1.9, 0.4] [5.0, 3.6, 1.4, 0.2] [4.7, 3.2, 1.3, 0.2] [6.3, 3.3, 4.7, 1.6] [6.1, 2.6, 5.6, 1.4]]
// One-hot encode the species column into a (rows x classes) label matrix.
// Column k of a row is 1.0 when that row's species is speciesClasses[k].
val irisLabel = iris["species"]
val speciesClasses = listOf("Iris-setosa", "Iris-versicolor", "Iris-virginica")
val rowLabel = iris.nrow
val colLabel = speciesClasses.size
val twodimLabel = Array(rowLabel) { DoubleArray(colLabel) }
for (r in 0 until rowLabel) {
    val label = irisLabel[r]
    val classIndex = speciesClasses.indexOfFirst { it == label }
    // Fail loudly instead of leaving an all-zero row: such a row would silently
    // feed a target with no class into the loss function.
    require(classIndex >= 0) { "Unknown species label '$label' at row $r" }
    twodimLabel[r][classIndex] = 1.0
}
println(Arrays.deepToString(twodimLabel).replace("], ", "]\n"))
[[1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [0.0, 0.0, 1.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 0.0, 1.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [0.0, 0.0, 1.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [0.0, 1.0, 0.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 0.0, 1.0] [0.0, 0.0, 1.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [0.0, 1.0, 0.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [0.0, 1.0, 0.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [1.0, 0.0, 
0.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 0.0, 1.0] [0.0, 0.0, 1.0] [0.0, 1.0, 0.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0] [0.0, 0.0, 1.0] [0.0, 0.0, 1.0] [0.0, 0.0, 1.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [1.0, 0.0, 0.0] [0.0, 0.0, 1.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [1.0, 0.0, 0.0] [0.0, 1.0, 0.0] [0.0, 0.0, 1.0]]
// Wrap the plain Kotlin matrices in ND4J arrays:
// features (one row per sample) and their one-hot labels.
val dataIn: INDArray = Nd4j.create(irisMatrix)
val dataOut: INDArray = Nd4j.create(twodimLabel)
import org.nd4j.linalg.lossfunctions.LossFunctions
val seed: Long = 6
// Feed-forward classifier: 4 inputs -> 3 (tanh) -> 3 -> 3-way softmax output,
// trained with negative log-likelihood.
val conf = NeuralNetConfiguration.Builder()
    .seed(seed) // fixed random seed so parameter initialisation is reproducible
    .updater(Nadam()) // Nadam optimizer (Adam with Nesterov momentum), default hyper-parameters
    .l2(1e-4) // L2 weight regularisation
    .list()
    .layer(DenseLayer.Builder()
        .nIn(4)
        .nOut(3)
        .activation(Activation.TANH)
        .weightInit(WeightInit.XAVIER)
        .build())
    // NOTE(review): no activation/weightInit is set on this layer, so it inherits
    // the NeuralNetConfiguration builder defaults - confirm that is intended.
    .layer(DenseLayer.Builder()
        .nIn(3)
        .nOut(3)
        .build())
    .layer(OutputLayer.Builder(LossFunctions.LossFunction.NEGATIVELOGLIKELIHOOD)
        .nIn(3)
        .nOut(3)
        .activation(Activation.SOFTMAX)
        .weightInit(WeightInit.XAVIER)
        .build())
    .build()
val model = MultiLayerNetwork(conf)
// Allocate and initialise the network parameters before any fit/output call.
model.init()
import org.nd4j.linalg.dataset.api.preprocessor.DataNormalization
import org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize
// Pair features with labels, shuffle deterministically with the shared seed,
// then keep 90% of the rows for training and hold out the remaining 10% for testing.
val fullDataSet = DataSet(dataIn, dataOut)
fullDataSet.shuffle(seed)
val splitedSet = fullDataSet.splitTestAndTrain(0.90)
val trainingData = splitedSet.train
val testData = splitedSet.test

// Standardize the features to zero mean and unit variance with NormalizerStandardize.
// fit() only collects mean/stdev from the *training* split and does not modify it;
// transform() then applies those same training statistics to both splits, so no
// information from the test set leaks into the normalisation.
val normalizer: DataNormalization = NormalizerStandardize()
normalizer.fit(trainingData)
for (split in listOf(trainingData, testData)) {
    normalizer.transform(split)
}
// Train the network: each fit(trainingData) call is one pass over the whole
// (normalised) training split. The listener logs the score every 100 iterations.
model.setListeners(ScoreIterationListener(100))
// `0..2000` is inclusive and ran 2001 passes; repeat(2000) runs exactly 2000.
repeat(2000) {
    model.fit(trainingData)
}
// Evaluate on the held-out test split: predict class probabilities, compare them
// with the true one-hot labels, and print the metrics report and confusion matrix.
val output: INDArray = model.output(testData.features)
val eval = Evaluation().also { it.eval(testData.labels, output) }
println("Score ${eval.stats()}")
Score ========================Evaluation Metrics======================== # of classes: 3 Accuracy: 1,0000 Precision: 1,0000 Recall: 1,0000 F1 Score: 1,0000 Precision, recall & F1: macro-averaged (equally weighted avg. of 3 classes) =========================Confusion Matrix========================= 0 1 2 ------- 5 0 0 | 0 = 0 0 3 0 | 1 = 1 0 0 7 | 2 = 2 Confusion matrix format: Actual (rowClass) predicted as (columnClass) N times ==================================================================